diff --git a/.gitignore b/.gitignore
index 58bcbf8..4be0b9d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,32 +1,33 @@
-# Windows image file caches
-Thumbs.db
-ehthumbs.db
-
-# Folder config file
-Desktop.ini
-
-# Recycle Bin used on file shares
-$RECYCLE.BIN/
-
-# Windows Installer files
-*.cab
-*.msi
-*.msm
-*.msp
-
-# =========================
-# Operating System Files
-# =========================
-
-# OSX
-# =========================
-
+# Windows image file caches
+Thumbs.db
+ehthumbs.db
+
+# Folder config file
+Desktop.ini
+
+# Recycle Bin used on file shares
+$RECYCLE.BIN/
+
+# Windows Installer files
+*.cab
+*.msi
+*.msm
+*.msp
+
+# =========================
+# Operating System Files
+# =========================
+
+# OSX
+# =========================
+
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
-Icon
+Icon
+
# Thumbnails
._*
@@ -41,3 +42,7 @@ Icon
Network Trash Folder
Temporary Items
.apdisk
+
+# Intellij idea
+.idea
+out
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/AssociationAnalysis/DataMining_Apriori/AprioriTool.java b/AssociationAnalysis/DataMining_Apriori/AprioriTool.java
index 2cd08b4..972872c 100644
--- a/AssociationAnalysis/DataMining_Apriori/AprioriTool.java
+++ b/AssociationAnalysis/DataMining_Apriori/AprioriTool.java
@@ -1,432 +1,417 @@
-package DataMining_Apriori;
+package AssociationAnalysis.DataMining_Apriori;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
+import java.util.*;
/**
* apriori算法工具类
- *
- * @author lyq
- *
+ *
+ * @author Qstar
*/
-public class AprioriTool {
- // 最小支持度计数
- private int minSupportCount;
- // 测试数据文件地址
- private String filePath;
- // 每个事务中的商品ID
- private ArrayList totalGoodsIDs;
- // 过程中计算出来的所有频繁项集列表
- private ArrayList resultItem;
- // 过程中计算出来频繁项集的ID集合
- private ArrayList resultItemID;
-
- public AprioriTool(String filePath, int minSupportCount) {
- this.filePath = filePath;
- this.minSupportCount = minSupportCount;
- readDataFile();
- }
-
- /**
- * 从文件中读取数据
- */
- private void readDataFile() {
- File file = new File(filePath);
- ArrayList dataArray = new ArrayList();
-
- try {
- BufferedReader in = new BufferedReader(new FileReader(file));
- String str;
- String[] tempArray;
- while ((str = in.readLine()) != null) {
- tempArray = str.split(" ");
- dataArray.add(tempArray);
- }
- in.close();
- } catch (IOException e) {
- e.getStackTrace();
- }
-
- String[] temp = null;
- totalGoodsIDs = new ArrayList<>();
- for (String[] array : dataArray) {
- temp = new String[array.length - 1];
- System.arraycopy(array, 1, temp, 0, array.length - 1);
-
- // 将事务ID加入列表吧中
- totalGoodsIDs.add(temp);
- }
- }
-
- /**
- * 判读字符数组array2是否包含于数组array1中
- *
- * @param array1
- * @param array2
- * @return
- */
- public boolean iSStrContain(String[] array1, String[] array2) {
- if (array1 == null || array2 == null) {
- return false;
- }
-
- boolean iSContain = false;
- for (String s : array2) {
- // 新的字母比较时,重新初始化变量
- iSContain = false;
- // 判读array2中每个字符,只要包括在array1中 ,就算包含
- for (String s2 : array1) {
- if (s.equals(s2)) {
- iSContain = true;
- break;
- }
- }
-
- // 如果已经判断出不包含了,则直接中断循环
- if (!iSContain) {
- break;
- }
- }
-
- return iSContain;
- }
-
- /**
- * 项集进行连接运算
- */
- private void computeLink() {
- // 连接计算的终止数,k项集必须算到k-1子项集为止
- int endNum = 0;
- // 当前已经进行连接运算到几项集,开始时就是1项集
- int currentNum = 1;
- // 商品,1频繁项集映射图
- HashMap itemMap = new HashMap<>();
- FrequentItem tempItem;
- // 初始列表
- ArrayList list = new ArrayList<>();
- // 经过连接运算后产生的结果项集
- resultItem = new ArrayList<>();
- resultItemID = new ArrayList<>();
- // 商品ID的种类
- ArrayList idType = new ArrayList<>();
- for (String[] a : totalGoodsIDs) {
- for (String s : a) {
- if (!idType.contains(s)) {
- tempItem = new FrequentItem(new String[] { s }, 1);
- idType.add(s);
- resultItemID.add(new String[] { s });
- } else {
- // 支持度计数加1
- tempItem = itemMap.get(s);
- tempItem.setCount(tempItem.getCount() + 1);
- }
- itemMap.put(s, tempItem);
- }
- }
- // 将初始频繁项集转入到列表中,以便继续做连接运算
- for (Map.Entry entry : itemMap.entrySet()) {
- list.add((FrequentItem) entry.getValue());
- }
- // 按照商品ID进行排序,否则连接计算结果将会不一致,将会减少
- Collections.sort(list);
- resultItem.addAll(list);
-
- String[] array1;
- String[] array2;
- String[] resultArray;
- ArrayList tempIds;
- ArrayList resultContainer;
- // 总共要算到endNum项集
- endNum = list.size() - 1;
-
- while (currentNum < endNum) {
- resultContainer = new ArrayList<>();
- for (int i = 0; i < list.size() - 1; i++) {
- tempItem = list.get(i);
- array1 = tempItem.getIdArray();
- for (int j = i + 1; j < list.size(); j++) {
- tempIds = new ArrayList<>();
- array2 = list.get(j).getIdArray();
- for (int k = 0; k < array1.length; k++) {
- // 如果对应位置上的值相等的时候,只取其中一个值,做了一个连接删除操作
- if (array1[k].equals(array2[k])) {
- tempIds.add(array1[k]);
- } else {
- tempIds.add(array1[k]);
- tempIds.add(array2[k]);
- }
- }
- resultArray = new String[tempIds.size()];
- tempIds.toArray(resultArray);
-
- boolean isContain = false;
- // 过滤不符合条件的的ID数组,包括重复的和长度不符合要求的
- if (resultArray.length == (array1.length + 1)) {
- isContain = isIDArrayContains(resultContainer,
- resultArray);
- if (!isContain) {
- resultContainer.add(resultArray);
- }
- }
- }
- }
-
- // 做频繁项集的剪枝处理,必须保证新的频繁项集的子项集也必须是频繁项集
- list = cutItem(resultContainer);
- currentNum++;
- }
-
- // 输出频繁项集
- for (int k = 1; k <= currentNum; k++) {
- System.out.println("频繁" + k + "项集:");
- for (FrequentItem i : resultItem) {
- if (i.getLength() == k) {
- System.out.print("{");
- for (String t : i.getIdArray()) {
- System.out.print(t + ",");
- }
- System.out.print("},");
- }
- }
- System.out.println();
- }
- }
-
- /**
- * 判断列表结果中是否已经包含此数组
- *
- * @param container
- * ID数组容器
- * @param array
- * 待比较数组
- * @return
- */
- private boolean isIDArrayContains(ArrayList container,
- String[] array) {
- boolean isContain = true;
- if (container.size() == 0) {
- isContain = false;
- return isContain;
- }
-
- for (String[] s : container) {
- // 比较的视乎必须保证长度一样
- if (s.length != array.length) {
- continue;
- }
-
- isContain = true;
- for (int i = 0; i < s.length; i++) {
- // 只要有一个id不等,就算不相等
- if (s[i] != array[i]) {
- isContain = false;
- break;
- }
- }
-
- // 如果已经判断是包含在容器中时,直接退出
- if (isContain) {
- break;
- }
- }
-
- return isContain;
- }
-
- /**
- * 对频繁项集做剪枝步骤,必须保证新的频繁项集的子项集也必须是频繁项集
- */
- private ArrayList cutItem(ArrayList resultIds) {
- String[] temp;
- // 忽略的索引位置,以此构建子集
- int igNoreIndex = 0;
- FrequentItem tempItem;
- // 剪枝生成新的频繁项集
- ArrayList newItem = new ArrayList<>();
- // 不符合要求的id
- ArrayList deleteIdArray = new ArrayList<>();
- // 子项集是否也为频繁子项集
- boolean isContain = true;
-
- for (String[] array : resultIds) {
- // 列举出其中的一个个的子项集,判断存在于频繁项集列表中
- temp = new String[array.length - 1];
- for (igNoreIndex = 0; igNoreIndex < array.length; igNoreIndex++) {
- isContain = true;
- for (int j = 0, k = 0; j < array.length; j++) {
- if (j != igNoreIndex) {
- temp[k] = array[j];
- k++;
- }
- }
-
- if (!isIDArrayContains(resultItemID, temp)) {
- isContain = false;
- break;
- }
- }
-
- if (!isContain) {
- deleteIdArray.add(array);
- }
- }
-
- // 移除不符合条件的ID组合
- resultIds.removeAll(deleteIdArray);
-
- // 移除支持度计数不够的id集合
- int tempCount = 0;
- for (String[] array : resultIds) {
- tempCount = 0;
- for (String[] array2 : totalGoodsIDs) {
- if (isStrArrayContain(array2, array)) {
- tempCount++;
- }
- }
-
- // 如果支持度计数大于等于最小最小支持度计数则生成新的频繁项集,并加入结果集中
- if (tempCount >= minSupportCount) {
- tempItem = new FrequentItem(array, tempCount);
- newItem.add(tempItem);
- resultItemID.add(array);
- resultItem.add(tempItem);
- }
- }
-
- return newItem;
- }
-
- /**
- * 数组array2是否包含于array1中,不需要完全一样
- *
- * @param array1
- * @param array2
- * @return
- */
- private boolean isStrArrayContain(String[] array1, String[] array2) {
- boolean isContain = true;
- for (String s2 : array2) {
- isContain = false;
- for (String s1 : array1) {
- // 只要s2字符存在于array1中,这个字符就算包含在array1中
- if (s2.equals(s1)) {
- isContain = true;
- break;
- }
- }
-
- // 一旦发现不包含的字符,则array2数组不包含于array1中
- if (!isContain) {
- break;
- }
- }
-
- return isContain;
- }
-
- /**
- * 根据产生的频繁项集输出关联规则
- *
- * @param minConf
- * 最小置信度阈值
- */
- public void printAttachRule(double minConf) {
- // 进行连接和剪枝操作
- computeLink();
-
- int count1 = 0;
- int count2 = 0;
- ArrayList childGroup1;
- ArrayList childGroup2;
- String[] group1;
- String[] group2;
- // 以最后一个频繁项集做关联规则的输出
- String[] array = resultItem.get(resultItem.size() - 1).getIdArray();
- // 子集总数,计算的时候除去自身和空集
- int totalNum = (int) Math.pow(2, array.length);
- String[] temp;
- // 二进制数组,用来代表各个子集
- int[] binaryArray;
- // 除去头和尾部
- for (int i = 1; i < totalNum - 1; i++) {
- binaryArray = new int[array.length];
- numToBinaryArray(binaryArray, i);
-
- childGroup1 = new ArrayList<>();
- childGroup2 = new ArrayList<>();
- count1 = 0;
- count2 = 0;
- // 按照二进制位关系取出子集
- for (int j = 0; j < binaryArray.length; j++) {
- if (binaryArray[j] == 1) {
- childGroup1.add(array[j]);
- } else {
- childGroup2.add(array[j]);
- }
- }
-
- group1 = new String[childGroup1.size()];
- group2 = new String[childGroup2.size()];
-
- childGroup1.toArray(group1);
- childGroup2.toArray(group2);
-
- for (String[] a : totalGoodsIDs) {
- if (isStrArrayContain(a, group1)) {
- count1++;
-
- // 在group1的条件下,统计group2的事件发生次数
- if (isStrArrayContain(a, group2)) {
- count2++;
- }
- }
- }
-
- // {A}-->{B}的意思为在A的情况下发生B的概率
- System.out.print("{");
- for (String s : group1) {
- System.out.print(s + ", ");
- }
- System.out.print("}-->");
- System.out.print("{");
- for (String s : group2) {
- System.out.print(s + ", ");
- }
- System.out.print(MessageFormat.format(
- "},confidence(置信度):{0}/{1}={2}", count2, count1, count2
- * 1.0 / count1));
- if (count2 * 1.0 / count1 < minConf) {
- // 不符合要求,不是强规则
- System.out.println("由于此规则置信度未达到最小置信度的要求,不是强规则");
- } else {
- System.out.println("为强规则");
- }
- }
-
- }
-
- /**
- * 数字转为二进制形式
- *
- * @param binaryArray
- * 转化后的二进制数组形式
- * @param num
- * 待转化数字
- */
- private void numToBinaryArray(int[] binaryArray, int num) {
- int index = 0;
- while (num != 0) {
- binaryArray[index] = num % 2;
- index++;
- num /= 2;
- }
- }
+class AprioriTool {
+ // 最小支持度计数
+ private int minSupportCount;
+ // 测试数据文件地址
+ private String filePath;
+ // 每个事务中的商品ID
+ private ArrayList totalGoodsIDs;
+ // 过程中计算出来的所有频繁项集列表
+ private ArrayList resultItem;
+ // 过程中计算出来频繁项集的ID集合
+ private ArrayList resultItemID;
+
+ AprioriTool(String filePath, int minSupportCount){
+ this.filePath = filePath;
+ this.minSupportCount = minSupportCount;
+ readDataFile();
+ }
+
+ /**
+ * 从文件中读取数据
+ */
+ private void readDataFile(){
+ File file = new File(filePath);
+ ArrayList dataArray = new ArrayList<>();
+
+ try {
+ BufferedReader in = new BufferedReader(new FileReader(file));
+ String str;
+ String[] tempArray;
+ while ((str = in.readLine()) != null) {
+ tempArray = str.split(" ");
+ dataArray.add(tempArray);
+ }
+ in.close();
+ } catch (IOException e) {
+ e.getStackTrace();
+ }
+
+ String[] temp;
+ totalGoodsIDs = new ArrayList<>();
+ for (String[] array : dataArray) {
+ temp = new String[array.length - 1];
+ System.arraycopy(array, 1, temp, 0, array.length - 1);
+
+            // 将事务ID加入列表中
+ totalGoodsIDs.add(temp);
+ }
+ }
+
+ /**
+     * 判断字符数组array2是否包含于数组array1中
+ *
+ * @param array1 字符数组1
+ * @param array2 字符数组2
+ */
+ public boolean iSStrContain(String[] array1, String[] array2){
+ if (array1 == null || array2 == null) {
+ return false;
+ }
+
+ boolean iSContain = false;
+ for (String s : array2) {
+ // 新的字母比较时,重新初始化变量
+ iSContain = false;
+            // 判断array2中每个字符,只要包括在array1中,就算包含
+ for (String s2 : array1) {
+ if (s.equals(s2)) {
+ iSContain = true;
+ break;
+ }
+ }
+
+ // 如果已经判断出不包含了,则直接中断循环
+ if (!iSContain) {
+ break;
+ }
+ }
+
+ return iSContain;
+ }
+
+ /**
+ * 项集进行连接运算
+ */
+ private void computeLink(){
+ // 连接计算的终止数,k项集必须算到k-1子项集为止
+ int endNum;
+ // 当前已经进行连接运算到几项集,开始时就是1项集
+ int currentNum = 1;
+ // 商品,1频繁项集映射图
+ HashMap itemMap = new HashMap<>();
+ FrequentItem tempItem;
+ // 初始列表
+ ArrayList list = new ArrayList<>();
+ // 经过连接运算后产生的结果项集
+ resultItem = new ArrayList<>();
+ resultItemID = new ArrayList<>();
+ // 商品ID的种类
+ ArrayList idType = new ArrayList<>();
+ for (String[] a : totalGoodsIDs) {
+ for (String s : a) {
+ if (!idType.contains(s)) {
+ tempItem = new FrequentItem(new String[]{s}, 1);
+ idType.add(s);
+ resultItemID.add(new String[]{s});
+ } else {
+ // 支持度计数加1
+ tempItem = itemMap.get(s);
+ tempItem.setCount(tempItem.getCount() + 1);
+ }
+ itemMap.put(s, tempItem);
+ }
+ }
+ // 将初始频繁项集转入到列表中,以便继续做连接运算
+ for (Map.Entry entry : itemMap.entrySet()) {
+ list.add((FrequentItem) entry.getValue());
+ }
+ // 按照商品ID进行排序,否则连接计算结果将会不一致,将会减少
+ Collections.sort(list);
+ resultItem.addAll(list);
+
+ String[] array1;
+ String[] array2;
+ String[] resultArray;
+ ArrayList tempIds;
+ ArrayList resultContainer;
+ // 总共要算到endNum项集
+ endNum = list.size() - 1;
+
+ while (currentNum < endNum) {
+ resultContainer = new ArrayList<>();
+ for (int i = 0; i < list.size() - 1; i++) {
+ tempItem = list.get(i);
+ array1 = tempItem.getIdArray();
+ for (int j = i + 1; j < list.size(); j++) {
+ tempIds = new ArrayList<>();
+ array2 = list.get(j).getIdArray();
+ for (int k = 0; k < array1.length; k++) {
+ // 如果对应位置上的值相等的时候,只取其中一个值,做了一个连接删除操作
+ if (array1[k].equals(array2[k])) {
+ tempIds.add(array1[k]);
+ } else {
+ tempIds.add(array1[k]);
+ tempIds.add(array2[k]);
+ }
+ }
+ resultArray = new String[tempIds.size()];
+ tempIds.toArray(resultArray);
+
+ boolean isContain;
+ // 过滤不符合条件的的ID数组,包括重复的和长度不符合要求的
+ if (resultArray.length == (array1.length + 1)) {
+ isContain = isIDArrayContains(resultContainer,
+ resultArray);
+ if (!isContain) {
+ resultContainer.add(resultArray);
+ }
+ }
+ }
+ }
+
+ // 做频繁项集的剪枝处理,必须保证新的频繁项集的子项集也必须是频繁项集
+ list = cutItem(resultContainer);
+ currentNum++;
+ }
+
+ // 输出频繁项集
+ for (int k = 1; k <= currentNum; k++) {
+ System.out.println("频繁" + k + "项集:");
+ for (FrequentItem i : resultItem) {
+ if (i.getLength() == k) {
+ System.out.print("{");
+ for (String t : i.getIdArray()) {
+ System.out.print(t + ",");
+ }
+ System.out.print("},");
+ }
+ }
+ System.out.println();
+ }
+ }
+
+ /**
+ * 判断列表结果中是否已经包含此数组
+ *
+ * @param container ID数组容器
+ * @param array 待比较数组
+ */
+ private boolean isIDArrayContains(ArrayList container,
+ String[] array){
+ boolean isContain = true;
+ if (container.size() == 0) {
+ return false;
+ }
+
+ for (String[] s : container) {
+            // 比较的时候必须保证长度一样
+ if (s.length != array.length) {
+ continue;
+ }
+
+ isContain = true;
+ for (int i = 0; i < s.length; i++) {
+ // 只要有一个id不等,就算不相等
+ if (!Objects.equals(s[i], array[i])) {
+ isContain = false;
+ break;
+ }
+ }
+
+ // 如果已经判断是包含在容器中时,直接退出
+ if (isContain) {
+ break;
+ }
+ }
+
+ return isContain;
+ }
+
+ /**
+ * 对频繁项集做剪枝步骤,必须保证新的频繁项集的子项集也必须是频繁项集
+ */
+ private ArrayList cutItem(ArrayList resultIds){
+ String[] temp;
+ // 忽略的索引位置,以此构建子集
+ int igNoreIndex;
+ FrequentItem tempItem;
+ // 剪枝生成新的频繁项集
+ ArrayList newItem = new ArrayList<>();
+ // 不符合要求的id
+ ArrayList deleteIdArray = new ArrayList<>();
+ // 子项集是否也为频繁子项集
+ boolean isContain = true;
+
+ for (String[] array : resultIds) {
+ // 列举出其中的一个个的子项集,判断存在于频繁项集列表中
+ temp = new String[array.length - 1];
+ for (igNoreIndex = 0; igNoreIndex < array.length; igNoreIndex++) {
+ isContain = true;
+ for (int j = 0, k = 0; j < array.length; j++) {
+ if (j != igNoreIndex) {
+ temp[k] = array[j];
+ k++;
+ }
+ }
+
+ if (!isIDArrayContains(resultItemID, temp)) {
+ isContain = false;
+ break;
+ }
+ }
+
+ if (!isContain) {
+ deleteIdArray.add(array);
+ }
+ }
+
+ // 移除不符合条件的ID组合
+ resultIds.removeAll(deleteIdArray);
+
+ // 移除支持度计数不够的id集合
+ int tempCount;
+ for (String[] array : resultIds) {
+ tempCount = 0;
+ for (String[] array2 : totalGoodsIDs) {
+ if (isStrArrayContain(array2, array)) {
+ tempCount++;
+ }
+ }
+
+            // 如果支持度计数大于等于最小支持度计数则生成新的频繁项集,并加入结果集中
+ if (tempCount >= minSupportCount) {
+ tempItem = new FrequentItem(array, tempCount);
+ newItem.add(tempItem);
+ resultItemID.add(array);
+ resultItem.add(tempItem);
+ }
+ }
+
+ return newItem;
+ }
+
+ /**
+ * 数组array2是否包含于array1中,不需要完全一样
+ *
+ * @param array1
+ * @param array2
+ */
+ private boolean isStrArrayContain(String[] array1, String[] array2){
+ boolean isContain = true;
+ for (String s2 : array2) {
+ isContain = false;
+ for (String s1 : array1) {
+ // 只要s2字符存在于array1中,这个字符就算包含在array1中
+ if (s2.equals(s1)) {
+ isContain = true;
+ break;
+ }
+ }
+
+ // 一旦发现不包含的字符,则array2数组不包含于array1中
+ if (!isContain) {
+ break;
+ }
+ }
+
+ return isContain;
+ }
+
+ /**
+ * 根据产生的频繁项集输出关联规则
+ *
+ * @param minConf 最小置信度阈值
+ */
+ void printAttachRule(double minConf){
+ // 进行连接和剪枝操作
+ computeLink();
+
+ int count1;
+ int count2;
+ ArrayList childGroup1;
+ ArrayList childGroup2;
+ String[] group1;
+ String[] group2;
+ // 以最后一个频繁项集做关联规则的输出
+ String[] array = resultItem.get(resultItem.size() - 1).getIdArray();
+ // 子集总数,计算的时候除去自身和空集
+ int totalNum = (int) Math.pow(2, array.length);
+ // 二进制数组,用来代表各个子集
+ int[] binaryArray;
+ // 除去头和尾部
+ for (int i = 1; i < totalNum - 1; i++) {
+ binaryArray = new int[array.length];
+ numToBinaryArray(binaryArray, i);
+
+ childGroup1 = new ArrayList<>();
+ childGroup2 = new ArrayList<>();
+ count1 = 0;
+ count2 = 0;
+ // 按照二进制位关系取出子集
+ for (int j = 0; j < binaryArray.length; j++) {
+ if (binaryArray[j] == 1) {
+ childGroup1.add(array[j]);
+ } else {
+ childGroup2.add(array[j]);
+ }
+ }
+
+ group1 = new String[childGroup1.size()];
+ group2 = new String[childGroup2.size()];
+
+ childGroup1.toArray(group1);
+ childGroup2.toArray(group2);
+
+ for (String[] a : totalGoodsIDs) {
+ if (isStrArrayContain(a, group1)) {
+ count1++;
+
+ // 在group1的条件下,统计group2的事件发生次数
+ if (isStrArrayContain(a, group2)) {
+ count2++;
+ }
+ }
+ }
+
+ // {A}-->{B}的意思为在A的情况下发生B的概率
+ System.out.print("{");
+ for (String s : group1) {
+ System.out.print(s + ", ");
+ }
+ System.out.print("}-->");
+ System.out.print("{");
+ for (String s : group2) {
+ System.out.print(s + ", ");
+ }
+ System.out.print(MessageFormat.format(
+ "},confidence(置信度):{0}/{1}={2}", count2, count1, count2
+ * 1.0 / count1));
+ if (count2 * 1.0 / count1 < minConf) {
+ // 不符合要求,不是强规则
+ System.out.println("由于此规则置信度未达到最小置信度的要求,不是强规则");
+ } else {
+ System.out.println("为强规则");
+ }
+ }
+ }
+
+ /**
+ * 数字转为二进制形式
+ *
+ * @param binaryArray 转化后的二进制数组形式
+ * @param num 待转化数字
+ */
+ private void numToBinaryArray(int[] binaryArray, int num){
+ int index = 0;
+ while (num != 0) {
+ binaryArray[index] = num % 2;
+ index++;
+ num /= 2;
+ }
+ }
}
diff --git a/AssociationAnalysis/DataMining_Apriori/Client.java b/AssociationAnalysis/DataMining_Apriori/Client.java
index 7791c6f..726ae8d 100644
--- a/AssociationAnalysis/DataMining_Apriori/Client.java
+++ b/AssociationAnalysis/DataMining_Apriori/Client.java
@@ -1,15 +1,15 @@
-package DataMining_Apriori;
+package AssociationAnalysis.DataMining_Apriori;
/**
* apriori关联规则挖掘算法调用类
- * @author lyq
*
+ * @author Qstar
*/
public class Client {
- public static void main(String[] args){
- String filePath = "C:\\Users\\lyq\\Desktop\\icon\\testInput.txt";
-
- AprioriTool tool = new AprioriTool(filePath, 2);
- tool.printAttachRule(0.7);
- }
+ public static void main(String[] args){
+ String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/AssociationAnalysis/DataMining_Apriori/testInput.txt";
+
+ AprioriTool tool = new AprioriTool(filePath, 2);
+ tool.printAttachRule(0.7);
+ }
}
diff --git a/AssociationAnalysis/DataMining_Apriori/FrequentItem.java b/AssociationAnalysis/DataMining_Apriori/FrequentItem.java
index 592d40d..bd3b9ab 100644
--- a/AssociationAnalysis/DataMining_Apriori/FrequentItem.java
+++ b/AssociationAnalysis/DataMining_Apriori/FrequentItem.java
@@ -1,56 +1,51 @@
-package DataMining_Apriori;
+package AssociationAnalysis.DataMining_Apriori;
/**
* 频繁项集
- *
- * @author lyq
- *
+ *
+ * @author Qstar
*/
-public class FrequentItem implements Comparable{
- // 频繁项集的集合ID
- private String[] idArray;
- // 频繁项集的支持度计数
- private int count;
- //频繁项集的长度,1项集或是2项集,亦或是3项集
- private int length;
-
- public FrequentItem(String[] idArray, int count){
- this.idArray = idArray;
- this.count = count;
- length = idArray.length;
- }
-
- public String[] getIdArray() {
- return idArray;
- }
-
- public void setIdArray(String[] idArray) {
- this.idArray = idArray;
- }
-
- public int getCount() {
- return count;
- }
-
- public void setCount(int count) {
- this.count = count;
- }
-
- public int getLength() {
- return length;
- }
-
- public void setLength(int length) {
- this.length = length;
- }
-
- @Override
- public int compareTo(FrequentItem o) {
- // TODO Auto-generated method stub
- Integer int1 = Integer.parseInt(this.getIdArray()[0]);
- Integer int2 = Integer.parseInt(o.getIdArray()[0]);
-
- return int1.compareTo(int2);
- }
-
+public class FrequentItem implements Comparable {
+ // 频繁项集的集合ID
+ private String[] idArray;
+ // 频繁项集的支持度计数
+ private int count;
+ //频繁项集的长度,1项集或是2项集,亦或是3项集
+ private int length;
+
+ public FrequentItem(String[] idArray, int count){
+ this.idArray = idArray;
+ this.count = count;
+ length = idArray.length;
+ }
+
+ public String[] getIdArray(){
+ return idArray;
+ }
+
+ public int getCount(){
+ return count;
+ }
+
+ public void setCount(int count){
+ this.count = count;
+ }
+
+ public int getLength(){
+ return length;
+ }
+
+ public void setLength(int length){
+ this.length = length;
+ }
+
+ @Override
+ public int compareTo(FrequentItem o){
+ // TODO Auto-generated method stub
+ Integer int1 = Integer.parseInt(this.getIdArray()[0]);
+ Integer int2 = Integer.parseInt(o.getIdArray()[0]);
+
+ return int1.compareTo(int2);
+ }
+
}
diff --git a/AssociationAnalysis/DataMining_FPTree/Client.java b/AssociationAnalysis/DataMining_FPTree/Client.java
index 69789b7..fe4f10d 100644
--- a/AssociationAnalysis/DataMining_FPTree/Client.java
+++ b/AssociationAnalysis/DataMining_FPTree/Client.java
@@ -1,17 +1,17 @@
-package DataMining_FPTree;
+package AssociationAnalysis.DataMining_FPTree;
/**
* FPTree频繁模式树算法
- * @author lyq
*
+ * @author Qstar
*/
public class Client {
- public static void main(String[] args){
- String filePath = "C:\\Users\\lyq\\Desktop\\icon\\testInput.txt";
- //最小支持度阈值
- int minSupportCount = 2;
-
- FPTreeTool tool = new FPTreeTool(filePath, minSupportCount);
- tool.startBuildingTree();
- }
+ public static void main(String[] args){
+ String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/AssociationAnalysis/DataMining_FPTree/testInput.txt";
+ //最小支持度阈值
+ int minSupportCount = 2;
+
+ FPTreeTool tool = new FPTreeTool(filePath, minSupportCount);
+ tool.startBuildingTree();
+ }
}
diff --git a/AssociationAnalysis/DataMining_FPTree/FPTreeTool.java b/AssociationAnalysis/DataMining_FPTree/FPTreeTool.java
index e53a946..3175955 100644
--- a/AssociationAnalysis/DataMining_FPTree/FPTreeTool.java
+++ b/AssociationAnalysis/DataMining_FPTree/FPTreeTool.java
@@ -1,4 +1,4 @@
-package DataMining_FPTree;
+package AssociationAnalysis.DataMining_FPTree;
import java.io.BufferedReader;
import java.io.File;
@@ -11,453 +11,439 @@
/**
* FPTree算法工具类
- *
- * @author lyq
- *
+ *
+ * @author Qstar
*/
-public class FPTreeTool {
- // 输入数据文件位置
- private String filePath;
- // 最小支持度阈值
- private int minSupportCount;
- // 所有事物ID记录
- private ArrayList totalGoodsID;
- // 各个ID的统计数目映射表项,计数用于排序使用
- private HashMap itemCountMap;
-
- public FPTreeTool(String filePath, int minSupportCount) {
- this.filePath = filePath;
- this.minSupportCount = minSupportCount;
- readDataFile();
- }
-
- /**
- * 从文件中读取数据
- */
- private void readDataFile() {
- File file = new File(filePath);
- ArrayList dataArray = new ArrayList();
-
- try {
- BufferedReader in = new BufferedReader(new FileReader(file));
- String str;
- String[] tempArray;
- while ((str = in.readLine()) != null) {
- tempArray = str.split(" ");
- dataArray.add(tempArray);
- }
- in.close();
- } catch (IOException e) {
- e.getStackTrace();
- }
-
- String[] temp;
- int count = 0;
- itemCountMap = new HashMap<>();
- totalGoodsID = new ArrayList<>();
- for (String[] a : dataArray) {
- temp = new String[a.length - 1];
- System.arraycopy(a, 1, temp, 0, a.length - 1);
- totalGoodsID.add(temp);
- for (String s : temp) {
- if (!itemCountMap.containsKey(s)) {
- count = 1;
- } else {
- count = ((int) itemCountMap.get(s));
- // 支持度计数加1
- count++;
- }
- // 更新表项
- itemCountMap.put(s, count);
- }
- }
- }
-
- /**
- * 根据事物记录构造FP树
- */
- private void buildFPTree(ArrayList suffixPattern,
- ArrayList> transctionList) {
- // 设置一个空根节点
- TreeNode rootNode = new TreeNode(null, 0);
- int count = 0;
- // 节点是否存在
- boolean isExist = false;
- ArrayList childNodes;
- ArrayList pathList;
- // 相同类型节点链表,用于构造的新的FP树
- HashMap> linkedNode = new HashMap<>();
- HashMap countNode = new HashMap<>();
- // 根据事物记录,一步步构建FP树
- for (ArrayList array : transctionList) {
- TreeNode searchedNode;
- pathList = new ArrayList<>();
- for (TreeNode node : array) {
- pathList.add(node);
- nodeCounted(node, countNode);
- searchedNode = searchNode(rootNode, pathList);
- childNodes = searchedNode.getChildNodes();
-
- if (childNodes == null) {
- childNodes = new ArrayList<>();
- childNodes.add(node);
- searchedNode.setChildNodes(childNodes);
- node.setParentNode(searchedNode);
- nodeAddToLinkedList(node, linkedNode);
- } else {
- isExist = false;
- for (TreeNode node2 : childNodes) {
- // 如果找到名称相同,则更新支持度计数
- if (node.getName().equals(node2.getName())) {
- count = node2.getCount() + node.getCount();
- node2.setCount(count);
- // 标识已找到节点位置
- isExist = true;
- break;
- }
- }
-
- if (!isExist) {
- // 如果没有找到,需添加子节点
- childNodes.add(node);
- node.setParentNode(searchedNode);
- nodeAddToLinkedList(node, linkedNode);
- }
- }
-
- }
- }
-
- // 如果FP树已经是单条路径,则输出此时的频繁模式
- if (isSinglePath(rootNode)) {
- printFrequentPattern(suffixPattern, rootNode);
- System.out.println("-------");
- } else {
- ArrayList> tList;
- ArrayList sPattern;
- if (suffixPattern == null) {
- sPattern = new ArrayList<>();
- } else {
- // 进行一个拷贝,避免互相引用的影响
- sPattern = (ArrayList) suffixPattern.clone();
- }
-
- // 利用节点链表构造新的事务
- for (Map.Entry entry : countNode.entrySet()) {
- // 添加到后缀模式中
- sPattern.add((String) entry.getKey());
- //获取到了条件模式机,作为新的事务
- tList = getTransactionList((String) entry.getKey(), linkedNode);
-
- System.out.print("[后缀模式]:{");
- for(String s: sPattern){
- System.out.print(s + ", ");
- }
- System.out.print("}, 此时的条件模式基:");
- for(ArrayList tnList: tList){
- System.out.print("{");
- for(TreeNode n: tnList){
- System.out.print(n.getName() + ", ");
- }
- System.out.print("}, ");
- }
- System.out.println();
- // 递归构造FP树
- buildFPTree(sPattern, tList);
- // 再次移除此项,构造不同的后缀模式,防止对后面造成干扰
- sPattern.remove((String) entry.getKey());
- }
- }
- }
-
- /**
- * 将节点加入到同类型节点的链表中
- *
- * @param node
- * 待加入节点
- * @param linkedList
- * 链表图
- */
- private void nodeAddToLinkedList(TreeNode node,
- HashMap> linkedList) {
- String name = node.getName();
- ArrayList list;
-
- if (linkedList.containsKey(name)) {
- list = linkedList.get(name);
- // 将node添加到此队列中
- list.add(node);
- } else {
- list = new ArrayList<>();
- list.add(node);
- linkedList.put(name, list);
- }
- }
-
- /**
- * 根据链表构造出新的事务
- *
- * @param name
- * 节点名称
- * @param linkedList
- * 链表
- * @return
- */
- private ArrayList> getTransactionList(String name,
- HashMap> linkedList) {
- ArrayList> tList = new ArrayList<>();
- ArrayList targetNode = linkedList.get(name);
- ArrayList singleTansaction;
- TreeNode temp;
-
- for (TreeNode node : targetNode) {
- singleTansaction = new ArrayList<>();
-
- temp = node;
- while (temp.getParentNode().getName() != null) {
- temp = temp.getParentNode();
- singleTansaction.add(new TreeNode(temp.getName(), 1));
- }
-
- // 按照支持度计数得反转一下
- Collections.reverse(singleTansaction);
-
- for (TreeNode node2 : singleTansaction) {
- // 支持度计数调成与模式后缀一样
- node2.setCount(node.getCount());
- }
-
- if (singleTansaction.size() > 0) {
- tList.add(singleTansaction);
- }
- }
-
- return tList;
- }
-
- /**
- * 节点计数
- *
- * @param node
- * 待加入节点
- * @param nodeCount
- * 计数映射图
- */
- private void nodeCounted(TreeNode node, HashMap nodeCount) {
- int count = 0;
- String name = node.getName();
-
- if (nodeCount.containsKey(name)) {
- count = nodeCount.get(name);
- count++;
- } else {
- count = 1;
- }
-
- nodeCount.put(name, count);
- }
-
- /**
- * 显示决策树
- *
- * @param node
- * 待显示的节点
- * @param blankNum
- * 行空格符,用于显示树型结构
- */
- private void showFPTree(TreeNode node, int blankNum) {
- System.out.println();
- for (int i = 0; i < blankNum; i++) {
- System.out.print("\t");
- }
- System.out.print("--");
- System.out.print("--");
-
- if (node.getChildNodes() == null) {
- System.out.print("[");
- System.out.print("I" + node.getName() + ":" + node.getCount());
- System.out.print("]");
- } else {
- // 递归显示子节点
- // System.out.print("【" + node.getName() + "】");
- for (TreeNode childNode : node.getChildNodes()) {
- showFPTree(childNode, 2 * blankNum);
- }
- }
-
- }
-
- /**
- * 待插入节点的抵达位置节点,从根节点开始向下寻找待插入节点的位置
- *
- * @param root
- * @param list
- * @return
- */
- private TreeNode searchNode(TreeNode node, ArrayList list) {
- ArrayList pathList = new ArrayList<>();
- TreeNode tempNode = null;
- TreeNode firstNode = list.get(0);
- boolean isExist = false;
- // 重新转一遍,避免出现同一引用
- for (TreeNode node2 : list) {
- pathList.add(node2);
- }
-
- // 如果没有孩子节点,则直接返回,在此节点下添加子节点
- if (node.getChildNodes() == null) {
- return node;
- }
-
- for (TreeNode n : node.getChildNodes()) {
- if (n.getName().equals(firstNode.getName()) && list.size() == 1) {
- tempNode = node;
- isExist = true;
- break;
- } else if (n.getName().equals(firstNode.getName())) {
- // 还没有找到最后的位置,继续找
- pathList.remove(firstNode);
- tempNode = searchNode(n, pathList);
- return tempNode;
- }
- }
-
- // 如果没有找到,则新添加到孩子节点中
- if (!isExist) {
- tempNode = node;
- }
-
- return tempNode;
- }
-
- /**
- * 判断目前构造的FP树是否是单条路径的
- *
- * @param rootNode
- * 当前FP树的根节点
- * @return
- */
- private boolean isSinglePath(TreeNode rootNode) {
- // 默认是单条路径
- boolean isSinglePath = true;
- ArrayList childList;
- TreeNode node;
- node = rootNode;
-
- while (node.getChildNodes() != null) {
- childList = node.getChildNodes();
- if (childList.size() == 1) {
- node = childList.get(0);
- } else {
- isSinglePath = false;
- break;
- }
- }
-
- return isSinglePath;
- }
-
- /**
- * 开始构建FP树
- */
- public void startBuildingTree() {
- ArrayList singleTransaction;
- ArrayList> transactionList = new ArrayList<>();
- TreeNode tempNode;
- int count = 0;
-
- for (String[] idArray : totalGoodsID) {
- singleTransaction = new ArrayList<>();
- for (String id : idArray) {
- count = itemCountMap.get(id);
- tempNode = new TreeNode(id, count);
- singleTransaction.add(tempNode);
- }
-
- // 根据支持度数的多少进行排序
- Collections.sort(singleTransaction);
- for (TreeNode node : singleTransaction) {
- // 支持度计数重新归为1
- node.setCount(1);
- }
- transactionList.add(singleTransaction);
- }
-
- buildFPTree(null, transactionList);
- }
-
- /**
- * 输出此单条路径下的频繁模式
- *
- * @param suffixPattern
- * 后缀模式
- * @param rootNode
- * 单条路径FP树根节点
- */
- private void printFrequentPattern(ArrayList suffixPattern,
- TreeNode rootNode) {
- ArrayList idArray = new ArrayList<>();
- TreeNode temp;
- temp = rootNode;
- // 用于输出组合模式
- int length = 0;
- int num = 0;
- int[] binaryArray;
-
- while (temp.getChildNodes() != null) {
- temp = temp.getChildNodes().get(0);
-
- // 筛选支持度系数大于最小阈值的值
- if (temp.getCount() >= minSupportCount) {
- idArray.add(temp.getName());
- }
- }
-
- length = idArray.size();
- num = (int) Math.pow(2, length);
- for (int i = 0; i < num; i++) {
- binaryArray = new int[length];
- numToBinaryArray(binaryArray, i);
-
- // 如果后缀模式只有1个,不能输出自身
- if (suffixPattern.size() == 1 && i == 0) {
- continue;
- }
-
- System.out.print("频繁模式:{【后缀模式:");
- // 先输出固有的后缀模式
- if (suffixPattern.size() > 1
- || (suffixPattern.size() == 1 && idArray.size() > 0)) {
- for (String s : suffixPattern) {
- System.out.print(s + ", ");
- }
- }
- System.out.print("】");
- // 输出路径上的组合模式
- for (int j = 0; j < length; j++) {
- if (binaryArray[j] == 1) {
- System.out.print(idArray.get(j) + ", ");
- }
- }
- System.out.println("}");
- }
- }
-
- /**
- * 数字转为二进制形式
- *
- * @param binaryArray
- * 转化后的二进制数组形式
- * @param num
- * 待转化数字
- */
- private void numToBinaryArray(int[] binaryArray, int num) {
- int index = 0;
- while (num != 0) {
- binaryArray[index] = num % 2;
- index++;
- num /= 2;
- }
- }
+class FPTreeTool {
+ // 输入数据文件位置
+ private String filePath;
+ // 最小支持度阈值
+ private int minSupportCount;
+    // 所有事务ID记录
+ private ArrayList totalGoodsID;
+ // 各个ID的统计数目映射表项,计数用于排序使用
+ private HashMap itemCountMap;
+
+ FPTreeTool(String filePath, int minSupportCount){
+ this.filePath = filePath;
+ this.minSupportCount = minSupportCount;
+ readDataFile();
+ }
+
+ /**
+ * 从文件中读取数据
+ */
+ private void readDataFile(){
+ File file = new File(filePath);
+ ArrayList dataArray = new ArrayList<>();
+
+ try {
+ BufferedReader in = new BufferedReader(new FileReader(file));
+ String str;
+ String[] tempArray;
+ while ((str = in.readLine()) != null) {
+ tempArray = str.split(" ");
+ dataArray.add(tempArray);
+ }
+ in.close();
+ } catch (IOException e) {
+ e.getStackTrace();
+ }
+
+ String[] temp;
+ int count;
+ itemCountMap = new HashMap<>();
+ totalGoodsID = new ArrayList<>();
+ for (String[] a : dataArray) {
+ temp = new String[a.length - 1];
+ System.arraycopy(a, 1, temp, 0, a.length - 1);
+ totalGoodsID.add(temp);
+ for (String s : temp) {
+ if (!itemCountMap.containsKey(s)) {
+ count = 1;
+ } else {
+ count = itemCountMap.get(s);
+ // 支持度计数加1
+ count++;
+ }
+ // 更新表项
+ itemCountMap.put(s, count);
+ }
+ }
+ }
+
+ /**
+     * 根据事务记录构造FP树
+ */
+ private void buildFPTree(ArrayList suffixPattern,
+ ArrayList> transctionList){
+ // 设置一个空根节点
+ TreeNode rootNode = new TreeNode(null, 0);
+ int count;
+ // 节点是否存在
+ boolean isExist;
+ ArrayList childNodes;
+ ArrayList pathList;
+ // 相同类型节点链表,用于构造的新的FP树
+ HashMap> linkedNode = new HashMap<>();
+ HashMap countNode = new HashMap<>();
+        // 根据事务记录,一步步构建FP树
+ for (ArrayList array : transctionList) {
+ TreeNode searchedNode;
+ pathList = new ArrayList<>();
+ for (TreeNode node : array) {
+ pathList.add(node);
+ nodeCounted(node, countNode);
+ searchedNode = searchNode(rootNode, pathList);
+ childNodes = searchedNode.getChildNodes();
+
+ if (childNodes == null) {
+ childNodes = new ArrayList<>();
+ childNodes.add(node);
+ searchedNode.setChildNodes(childNodes);
+ node.setParentNode(searchedNode);
+ nodeAddToLinkedList(node, linkedNode);
+ } else {
+ isExist = false;
+ for (TreeNode node2 : childNodes) {
+ // 如果找到名称相同,则更新支持度计数
+ if (node.getName().equals(node2.getName())) {
+ count = node2.getCount() + node.getCount();
+ node2.setCount(count);
+ // 标识已找到节点位置
+ isExist = true;
+ break;
+ }
+ }
+
+ if (!isExist) {
+ // 如果没有找到,需添加子节点
+ childNodes.add(node);
+ node.setParentNode(searchedNode);
+ nodeAddToLinkedList(node, linkedNode);
+ }
+ }
+
+ }
+ }
+
+ // 如果FP树已经是单条路径,则输出此时的频繁模式
+ if (isSinglePath(rootNode)) {
+ printFrequentPattern(suffixPattern, rootNode);
+ System.out.println("-------");
+ } else {
+ ArrayList> tList;
+ ArrayList sPattern;
+ if (suffixPattern == null) {
+ sPattern = new ArrayList<>();
+ } else {
+ // 进行一个拷贝,避免互相引用的影响
+ sPattern = (ArrayList) suffixPattern.clone();
+ }
+
+ // 利用节点链表构造新的事务
+ for (Map.Entry entry : countNode.entrySet()) {
+ // 添加到后缀模式中
+ sPattern.add((String) entry.getKey());
+                //获取到了条件模式基,作为新的事务
+ tList = getTransactionList((String) entry.getKey(), linkedNode);
+
+ System.out.print("[后缀模式]:{");
+ for (String s : sPattern) {
+ System.out.print(s + ", ");
+ }
+ System.out.print("}, 此时的条件模式基:");
+ for (ArrayList tnList : tList) {
+ System.out.print("{");
+ for (TreeNode n : tnList) {
+ System.out.print(n.getName() + ", ");
+ }
+ System.out.print("}, ");
+ }
+ System.out.println();
+ // 递归构造FP树
+ buildFPTree(sPattern, tList);
+ // 再次移除此项,构造不同的后缀模式,防止对后面造成干扰
+ sPattern.remove(entry.getKey());
+ }
+ }
+ }
+
+ /**
+ * 将节点加入到同类型节点的链表中
+ *
+ * @param node 待加入节点
+ * @param linkedList 链表图
+ */
+ private void nodeAddToLinkedList(TreeNode node,
+ HashMap> linkedList){
+ String name = node.getName();
+ ArrayList list;
+
+ if (linkedList.containsKey(name)) {
+ list = linkedList.get(name);
+ // 将node添加到此队列中
+ list.add(node);
+ } else {
+ list = new ArrayList<>();
+ list.add(node);
+ linkedList.put(name, list);
+ }
+ }
+
+ /**
+ * 根据链表构造出新的事务
+ *
+ * @param name 节点名称
+ * @param linkedList 链表
+ * @return
+ */
+ private ArrayList> getTransactionList(String name,
+ HashMap> linkedList){
+ ArrayList> tList = new ArrayList<>();
+ ArrayList targetNode = linkedList.get(name);
+ ArrayList singleTansaction;
+ TreeNode temp;
+
+ for (TreeNode node : targetNode) {
+ singleTansaction = new ArrayList<>();
+
+ temp = node;
+ while (temp.getParentNode().getName() != null) {
+ temp = temp.getParentNode();
+ singleTansaction.add(new TreeNode(temp.getName(), 1));
+ }
+
+            // 按照支持度计数的顺序反转一下
+ Collections.reverse(singleTansaction);
+
+ for (TreeNode node2 : singleTansaction) {
+ // 支持度计数调成与模式后缀一样
+ node2.setCount(node.getCount());
+ }
+
+ if (singleTansaction.size() > 0) {
+ tList.add(singleTansaction);
+ }
+ }
+
+ return tList;
+ }
+
+ /**
+ * 节点计数
+ *
+ * @param node 待加入节点
+ * @param nodeCount 计数映射图
+ */
+ private void nodeCounted(TreeNode node, HashMap nodeCount){
+ int count;
+ String name = node.getName();
+
+ if (nodeCount.containsKey(name)) {
+ count = nodeCount.get(name);
+ count++;
+ } else {
+ count = 1;
+ }
+
+ nodeCount.put(name, count);
+ }
+
+ /**
+     * 显示FP树
+ *
+ * @param node 待显示的节点
+ * @param blankNum 行空格符,用于显示树型结构
+ */
+ private void showFPTree(TreeNode node, int blankNum){
+ System.out.println();
+ for (int i = 0; i < blankNum; i++) {
+ System.out.print("\t");
+ }
+ System.out.print("--");
+ System.out.print("--");
+
+ if (node.getChildNodes() == null) {
+ System.out.print("[");
+ System.out.print("I" + node.getName() + ":" + node.getCount());
+ System.out.print("]");
+ } else {
+ // 递归显示子节点
+ // System.out.print("【" + node.getName() + "】");
+ for (TreeNode childNode : node.getChildNodes()) {
+ showFPTree(childNode, 2 * blankNum);
+ }
+ }
+
+ }
+
+ /**
+ * 待插入节点的抵达位置节点,从根节点开始向下寻找待插入节点的位置
+ *
+ * @param node
+ * @param list
+ * @return
+ */
+ private TreeNode searchNode(TreeNode node, ArrayList list){
+ ArrayList pathList = new ArrayList<>();
+ TreeNode tempNode = null;
+ TreeNode firstNode = list.get(0);
+ boolean isExist = false;
+ // 重新转一遍,避免出现同一引用
+ for (TreeNode node2 : list) {
+ pathList.add(node2);
+ }
+
+ // 如果没有孩子节点,则直接返回,在此节点下添加子节点
+ if (node.getChildNodes() == null) {
+ return node;
+ }
+
+ for (TreeNode n : node.getChildNodes()) {
+ if (n.getName().equals(firstNode.getName()) && list.size() == 1) {
+ tempNode = node;
+ isExist = true;
+ break;
+ } else if (n.getName().equals(firstNode.getName())) {
+ // 还没有找到最后的位置,继续找
+ pathList.remove(firstNode);
+ tempNode = searchNode(n, pathList);
+ return tempNode;
+ }
+ }
+
+ // 如果没有找到,则新添加到孩子节点中
+ if (!isExist) {
+ tempNode = node;
+ }
+
+ return tempNode;
+ }
+
+ /**
+ * 判断目前构造的FP树是否是单条路径的
+ *
+ * @param rootNode 当前FP树的根节点
+ * @return
+ */
+ private boolean isSinglePath(TreeNode rootNode){
+ // 默认是单条路径
+ boolean isSinglePath = true;
+ ArrayList childList;
+ TreeNode node;
+ node = rootNode;
+
+ while (node.getChildNodes() != null) {
+ childList = node.getChildNodes();
+ if (childList.size() == 1) {
+ node = childList.get(0);
+ } else {
+ isSinglePath = false;
+ break;
+ }
+ }
+
+ return isSinglePath;
+ }
+
+ /**
+ * 开始构建FP树
+ */
+ void startBuildingTree(){
+ ArrayList singleTransaction;
+ ArrayList> transactionList = new ArrayList<>();
+ TreeNode tempNode;
+ int count;
+
+ for (String[] idArray : totalGoodsID) {
+ singleTransaction = new ArrayList<>();
+ for (String id : idArray) {
+ count = itemCountMap.get(id);
+ tempNode = new TreeNode(id, count);
+ singleTransaction.add(tempNode);
+ }
+
+ // 根据支持度数的多少进行排序
+ Collections.sort(singleTransaction);
+ for (TreeNode node : singleTransaction) {
+ // 支持度计数重新归为1
+ node.setCount(1);
+ }
+ transactionList.add(singleTransaction);
+ }
+
+ buildFPTree(null, transactionList);
+ }
+
+ /**
+ * 输出此单条路径下的频繁模式
+ *
+ * @param suffixPattern 后缀模式
+ * @param rootNode 单条路径FP树根节点
+ */
+ private void printFrequentPattern(ArrayList suffixPattern,
+ TreeNode rootNode){
+ ArrayList idArray = new ArrayList<>();
+ TreeNode temp;
+ temp = rootNode;
+ // 用于输出组合模式
+ int length;
+ int num;
+ int[] binaryArray;
+
+ while (temp.getChildNodes() != null) {
+ temp = temp.getChildNodes().get(0);
+
+            // 筛选支持度计数大于等于最小阈值的项
+ if (temp.getCount() >= minSupportCount) {
+ idArray.add(temp.getName());
+ }
+ }
+
+ length = idArray.size();
+ num = (int) Math.pow(2, length);
+ for (int i = 0; i < num; i++) {
+ binaryArray = new int[length];
+ numToBinaryArray(binaryArray, i);
+
+ // 如果后缀模式只有1个,不能输出自身
+ if (suffixPattern.size() == 1 && i == 0) {
+ continue;
+ }
+
+ System.out.print("频繁模式:{【后缀模式:");
+ // 先输出固有的后缀模式
+ if (suffixPattern.size() > 1
+ || (suffixPattern.size() == 1 && idArray.size() > 0)) {
+ for (String s : suffixPattern) {
+ System.out.print(s + ", ");
+ }
+ }
+ System.out.print("】");
+ // 输出路径上的组合模式
+ for (int j = 0; j < length; j++) {
+ if (binaryArray[j] == 1) {
+ System.out.print(idArray.get(j) + ", ");
+ }
+ }
+ System.out.println("}");
+ }
+ }
+
+ /**
+ * 数字转为二进制形式
+ *
+ * @param binaryArray 转化后的二进制数组形式
+ * @param num 待转化数字
+ */
+ private void numToBinaryArray(int[] binaryArray, int num){
+ int index = 0;
+ while (num != 0) {
+ binaryArray[index] = num % 2;
+ index++;
+ num /= 2;
+ }
+ }
}
diff --git a/AssociationAnalysis/DataMining_FPTree/TreeNode.java b/AssociationAnalysis/DataMining_FPTree/TreeNode.java
index 154a94d..7201922 100644
--- a/AssociationAnalysis/DataMining_FPTree/TreeNode.java
+++ b/AssociationAnalysis/DataMining_FPTree/TreeNode.java
@@ -1,14 +1,14 @@
-package DataMining_FPTree;
+package AssociationAnalysis.DataMining_FPTree;
import java.util.ArrayList;
/**
* FP树节点
*
- * @author lyq
+ * @author Qstar
*
*/
-public class TreeNode implements Comparable, Cloneable{
+class TreeNode implements Comparable, Cloneable{
// 节点类别名称
private String name;
// 计数数量
@@ -18,7 +18,7 @@ public class TreeNode implements Comparable, Cloneable{
// 孩子节点,可以为多个
private ArrayList childNodes;
- public TreeNode(String name, int count){
+ TreeNode(String name, int count){
this.name = name;
this.count = count;
}
@@ -27,11 +27,7 @@ public String getName() {
return name;
}
- public void setName(String name) {
- this.name = name;
- }
-
- public Integer getCount() {
+ public Integer getCount() {
return count;
}
@@ -39,19 +35,19 @@ public void setCount(Integer count) {
this.count = count;
}
- public TreeNode getParentNode() {
+ TreeNode getParentNode() {
return parentNode;
}
- public void setParentNode(TreeNode parentNode) {
+ void setParentNode(TreeNode parentNode) {
this.parentNode = parentNode;
}
- public ArrayList getChildNodes() {
+ ArrayList getChildNodes() {
return childNodes;
}
- public void setChildNodes(ArrayList childNodes) {
+ void setChildNodes(ArrayList childNodes) {
this.childNodes = childNodes;
}
diff --git a/BaggingAndBoosting/DataMining_AdaBoost/AdaBoostTool.java b/BaggingAndBoosting/DataMining_AdaBoost/AdaBoostTool.java
index 9731e35..6665542 100644
--- a/BaggingAndBoosting/DataMining_AdaBoost/AdaBoostTool.java
+++ b/BaggingAndBoosting/DataMining_AdaBoost/AdaBoostTool.java
@@ -1,4 +1,4 @@
-package DataMining_AdaBoost;
+package BaggingAndBoosting.DataMining_AdaBoost;
import java.io.BufferedReader;
import java.io.File;
@@ -11,303 +11,293 @@
/**
* AdaBoost提升算法工具类
- *
- * @author lyq
- *
+ *
+ * @author Qstar
*/
-public class AdaBoostTool {
- // 分类的类别,程序默认为正类1和负类-1
- public static final int CLASS_POSITIVE = 1;
- public static final int CLASS_NEGTIVE = -1;
-
- // 事先假设的3个分类器(理论上应该重新对数据集进行训练得到)
- public static final String CLASSIFICATION1 = "X=2.5";
- public static final String CLASSIFICATION2 = "X=7.5";
- public static final String CLASSIFICATION3 = "Y=5.5";
-
- // 分类器组
- public static final String[] ClASSIFICATION = new String[] {
- CLASSIFICATION1, CLASSIFICATION2, CLASSIFICATION3 };
- // 分类权重组
- private double[] CLASSIFICATION_WEIGHT;
-
- // 测试数据文件地址
- private String filePath;
- // 误差率阈值
- private double errorValue;
- // 所有的数据点
- private ArrayList totalPoint;
-
- public AdaBoostTool(String filePath, double errorValue) {
- this.filePath = filePath;
- this.errorValue = errorValue;
- readDataFile();
- }
-
- /**
- * 从文件中读取数据
- */
- private void readDataFile() {
- File file = new File(filePath);
- ArrayList dataArray = new ArrayList();
-
- try {
- BufferedReader in = new BufferedReader(new FileReader(file));
- String str;
- String[] tempArray;
- while ((str = in.readLine()) != null) {
- tempArray = str.split(" ");
- dataArray.add(tempArray);
- }
- in.close();
- } catch (IOException e) {
- e.getStackTrace();
- }
-
- Point temp;
- totalPoint = new ArrayList<>();
- for (String[] array : dataArray) {
- temp = new Point(array[0], array[1], array[2]);
- temp.setProbably(1.0 / dataArray.size());
- totalPoint.add(temp);
- }
- }
-
- /**
- * 根据当前的误差值算出所得的权重
- *
- * @param errorValue
- * 当前划分的坐标点误差率
- * @return
- */
- private double calculateWeight(double errorValue) {
- double alpha = 0;
- double temp = 0;
-
- temp = (1 - errorValue) / errorValue;
- alpha = 0.5 * Math.log(temp);
-
- return alpha;
- }
-
- /**
- * 计算当前划分的误差率
- *
- * @param pointMap
- * 划分之后的点集
- * @param weight
- * 本次划分得到的分类器权重
- * @return
- */
- private double calculateErrorValue(
- HashMap> pointMap) {
- double resultValue = 0;
- double temp = 0;
- double weight = 0;
- int tempClassType;
- ArrayList pList;
- for (Map.Entry entry : pointMap.entrySet()) {
- tempClassType = (int) entry.getKey();
-
- pList = (ArrayList) entry.getValue();
- for (Point p : pList) {
- temp = p.getProbably();
- // 如果划分类型不相等,代表划错了
- if (tempClassType != p.getClassType()) {
- resultValue += temp;
- }
- }
- }
-
- weight = calculateWeight(resultValue);
- for (Map.Entry entry : pointMap.entrySet()) {
- tempClassType = (int) entry.getKey();
-
- pList = (ArrayList) entry.getValue();
- for (Point p : pList) {
- temp = p.getProbably();
- // 如果划分类型不相等,代表划错了
- if (tempClassType != p.getClassType()) {
- // 划错的点的权重比例变大
- temp *= Math.exp(weight);
- p.setProbably(temp);
- } else {
- // 划对的点的权重比减小
- temp *= Math.exp(-weight);
- p.setProbably(temp);
- }
- }
- }
-
- // 如果误差率没有小于阈值,继续处理
- dataNormalized();
-
- return resultValue;
- }
-
- /**
- * 概率做归一化处理
- */
- private void dataNormalized() {
- double sumProbably = 0;
- double temp = 0;
-
- for (Point p : totalPoint) {
- sumProbably += p.getProbably();
- }
-
- // 归一化处理
- for (Point p : totalPoint) {
- temp = p.getProbably();
- p.setProbably(temp / sumProbably);
- }
- }
-
- /**
- * 用AdaBoost算法得到的组合分类器对数据进行分类
- *
- */
- public void adaBoostClassify() {
- double value = 0;
- Point p;
-
- calculateWeightArray();
- for (int i = 0; i < ClASSIFICATION.length; i++) {
- System.out.println(MessageFormat.format("分类器{0}权重为:{1}", (i+1), CLASSIFICATION_WEIGHT[i]));
- }
-
- for (int j = 0; j < totalPoint.size(); j++) {
- p = totalPoint.get(j);
- value = 0;
-
- for (int i = 0; i < ClASSIFICATION.length; i++) {
- value += 1.0 * classifyData(ClASSIFICATION[i], p)
- * CLASSIFICATION_WEIGHT[i];
- }
-
- //进行符号判断
- if (value > 0) {
- System.out
- .println(MessageFormat.format(
- "点({0}, {1})的组合分类结果为:1,该点的实际分类为{2}", p.getX(), p.getY(),
- p.getClassType()));
- } else {
- System.out.println(MessageFormat.format(
- "点({0}, {1})的组合分类结果为:-1,该点的实际分类为{2}", p.getX(), p.getY(),
- p.getClassType()));
- }
- }
- }
-
- /**
- * 计算分类器权重数组
- */
- private void calculateWeightArray() {
- int tempClassType = 0;
- double errorValue = 0;
- ArrayList posPointList;
- ArrayList negPointList;
- HashMap> mapList;
- CLASSIFICATION_WEIGHT = new double[ClASSIFICATION.length];
-
- for (int i = 0; i < CLASSIFICATION_WEIGHT.length; i++) {
- mapList = new HashMap<>();
- posPointList = new ArrayList<>();
- negPointList = new ArrayList<>();
-
- for (Point p : totalPoint) {
- tempClassType = classifyData(ClASSIFICATION[i], p);
-
- if (tempClassType == CLASS_POSITIVE) {
- posPointList.add(p);
- } else {
- negPointList.add(p);
- }
- }
-
- mapList.put(CLASS_POSITIVE, posPointList);
- mapList.put(CLASS_NEGTIVE, negPointList);
-
- if (i == 0) {
- // 最开始的各个点的权重一样,所以传入0,使得e的0次方等于1
- errorValue = calculateErrorValue(mapList);
- } else {
- // 每次把上次计算所得的权重代入,进行概率的扩大或缩小
- errorValue = calculateErrorValue(mapList);
- }
-
- // 计算当前分类器的所得权重
- CLASSIFICATION_WEIGHT[i] = calculateWeight(errorValue);
- }
- }
-
- /**
- * 用各个子分类器进行分类
- *
- * @param classification
- * 分类器名称
- * @param p
- * 待划分坐标点
- * @return
- */
- private int classifyData(String classification, Point p) {
- // 分割线所属坐标轴
- String position;
- // 分割线的值
- double value = 0;
- double posProbably = 0;
- double negProbably = 0;
- // 划分是否是大于一边的划分
- boolean isLarger = false;
- String[] array;
- ArrayList pList = new ArrayList<>();
-
- array = classification.split("=");
- position = array[0];
- value = Double.parseDouble(array[1]);
-
- if (position.equals("X")) {
- if (p.getX() > value) {
- isLarger = true;
- }
-
- // 将训练数据中所有属于这边的点加入
- for (Point point : totalPoint) {
- if (isLarger && point.getX() > value) {
- pList.add(point);
- } else if (!isLarger && point.getX() < value) {
- pList.add(point);
- }
- }
- } else if (position.equals("Y")) {
- if (p.getY() > value) {
- isLarger = true;
- }
-
- // 将训练数据中所有属于这边的点加入
- for (Point point : totalPoint) {
- if (isLarger && point.getY() > value) {
- pList.add(point);
- } else if (!isLarger && point.getY() < value) {
- pList.add(point);
- }
- }
- }
-
- for (Point p2 : pList) {
- if (p2.getClassType() == CLASS_POSITIVE) {
- posProbably++;
- } else {
- negProbably++;
- }
- }
-
- //分类按正负类数量进行划分
- if (posProbably > negProbably) {
- return CLASS_POSITIVE;
- } else {
- return CLASS_NEGTIVE;
- }
- }
+class AdaBoostTool {
+ // 分类的类别,程序默认为正类1和负类-1
+ private static final int CLASS_POSITIVE = 1;
+ private static final int CLASS_NEGTIVE = -1;
+
+ // 事先假设的3个分类器(理论上应该重新对数据集进行训练得到)
+ private static final String CLASSIFICATION1 = "X=2.5";
+ private static final String CLASSIFICATION2 = "X=7.5";
+ private static final String CLASSIFICATION3 = "Y=5.5";
+
+ // 分类器组
+ private static final String[] ClASSIFICATION = new String[]{
+ CLASSIFICATION1, CLASSIFICATION2, CLASSIFICATION3};
+ // 分类权重组
+ private double[] CLASSIFICATION_WEIGHT;
+
+ // 测试数据文件地址
+ private String filePath;
+ // 误差率阈值
+ private double errorValue;
+ // 所有的数据点
+ private ArrayList totalPoint;
+
+ AdaBoostTool(String filePath, double errorValue){
+ this.filePath = filePath;
+ this.errorValue = errorValue;
+ readDataFile();
+ }
+
+ /**
+ * 从文件中读取数据
+ */
+ private void readDataFile(){
+ File file = new File(filePath);
+ ArrayList dataArray = new ArrayList<>();
+
+ try {
+ BufferedReader in = new BufferedReader(new FileReader(file));
+ String str;
+ String[] tempArray;
+ while ((str = in.readLine()) != null) {
+ tempArray = str.split(" ");
+ dataArray.add(tempArray);
+ }
+ in.close();
+ } catch (IOException e) {
+ e.getStackTrace();
+ }
+
+ Point temp;
+ totalPoint = new ArrayList<>();
+ for (String[] array : dataArray) {
+ temp = new Point(array[0], array[1], array[2]);
+ temp.setProbably(1.0 / dataArray.size());
+ totalPoint.add(temp);
+ }
+ }
+
+ /**
+ * 根据当前的误差值算出所得的权重
+ *
+ * @param errorValue 当前划分的坐标点误差率
+ * @return
+ */
+ private double calculateWeight(double errorValue){
+ double alpha;
+ double temp;
+
+ temp = (1 - errorValue) / errorValue;
+ alpha = 0.5 * Math.log(temp);
+
+ return alpha;
+ }
+
+ /**
+ * @param pointMap
+     * @return 本次划分的误差率(划错点的权重之和)
+ */
+ private double calculateErrorValue(
+ HashMap> pointMap){
+ double resultValue = 0;
+ double temp;
+ double weight;
+ int tempClassType;
+ ArrayList pList;
+ for (Map.Entry entry : pointMap.entrySet()) {
+ tempClassType = (int) entry.getKey();
+
+ pList = (ArrayList) entry.getValue();
+ for (Point p : pList) {
+ temp = p.getProbably();
+ // 如果划分类型不相等,代表划错了
+ if (tempClassType != p.getClassType()) {
+ resultValue += temp;
+ }
+ }
+ }
+
+ weight = calculateWeight(resultValue);
+ for (Map.Entry entry : pointMap.entrySet()) {
+ tempClassType = (int) entry.getKey();
+
+ pList = (ArrayList) entry.getValue();
+ for (Point p : pList) {
+ temp = p.getProbably();
+ // 如果划分类型不相等,代表划错了
+ if (tempClassType != p.getClassType()) {
+ // 划错的点的权重比例变大
+ temp *= Math.exp(weight);
+ p.setProbably(temp);
+ } else {
+ // 划对的点的权重比减小
+ temp *= Math.exp(-weight);
+ p.setProbably(temp);
+ }
+ }
+ }
+
+ // 如果误差率没有小于阈值,继续处理
+ dataNormalized();
+
+ return resultValue;
+ }
+
+ /**
+ * 概率做归一化处理
+ */
+ private void dataNormalized(){
+ double sumProbably = 0;
+ double temp;
+
+ for (Point p : totalPoint) {
+ sumProbably += p.getProbably();
+ }
+
+ // 归一化处理
+ for (Point p : totalPoint) {
+ temp = p.getProbably();
+ p.setProbably(temp / sumProbably);
+ }
+ }
+
+ /**
+ * 用AdaBoost算法得到的组合分类器对数据进行分类
+ */
+ void adaBoostClassify(){
+ double value;
+ Point p;
+
+ calculateWeightArray();
+ for (int i = 0; i < ClASSIFICATION.length; i++) {
+ System.out.println(MessageFormat.format("分类器{0}权重为:{1}", (i + 1), CLASSIFICATION_WEIGHT[i]));
+ }
+
+ for (Point aTotalPoint : totalPoint) {
+ p = aTotalPoint;
+ value = 0;
+
+ for (int i = 0; i < ClASSIFICATION.length; i++) {
+ value += 1.0 * classifyData(ClASSIFICATION[i], p)
+ * CLASSIFICATION_WEIGHT[i];
+ }
+
+ //进行符号判断
+ if (value > 0) {
+ System.out
+ .println(MessageFormat.format(
+ "点({0}, {1})的组合分类结果为:1,该点的实际分类为{2}", p.getX(), p.getY(),
+ p.getClassType()));
+ } else {
+ System.out.println(MessageFormat.format(
+ "点({0}, {1})的组合分类结果为:-1,该点的实际分类为{2}", p.getX(), p.getY(),
+ p.getClassType()));
+ }
+ }
+ }
+
+ /**
+ * 计算分类器权重数组
+ */
+ private void calculateWeightArray(){
+ int tempClassType;
+ double errorValue;
+ ArrayList posPointList;
+ ArrayList negPointList;
+ HashMap> mapList;
+ CLASSIFICATION_WEIGHT = new double[ClASSIFICATION.length];
+
+ for (int i = 0; i < CLASSIFICATION_WEIGHT.length; i++) {
+ mapList = new HashMap<>();
+ posPointList = new ArrayList<>();
+ negPointList = new ArrayList<>();
+
+ for (Point p : totalPoint) {
+ tempClassType = classifyData(ClASSIFICATION[i], p);
+
+ if (tempClassType == CLASS_POSITIVE) {
+ posPointList.add(p);
+ } else {
+ negPointList.add(p);
+ }
+ }
+
+ mapList.put(CLASS_POSITIVE, posPointList);
+ mapList.put(CLASS_NEGTIVE, negPointList);
+
+ if (i == 0) {
+ // 最开始的各个点的权重一样,所以传入0,使得e的0次方等于1
+ errorValue = calculateErrorValue(mapList);
+ } else {
+ // 每次把上次计算所得的权重代入,进行概率的扩大或缩小
+ errorValue = calculateErrorValue(mapList);
+ }
+
+ // 计算当前分类器的所得权重
+ CLASSIFICATION_WEIGHT[i] = calculateWeight(errorValue);
+ }
+ }
+
+ /**
+ * 用各个子分类器进行分类
+ *
+ * @param classification 分类器名称
+ * @param p 待划分坐标点
+ * @return
+ */
+ private int classifyData(String classification, Point p){
+ // 分割线所属坐标轴
+ String position;
+ // 分割线的值
+ double value;
+ double posProbably = 0;
+ double negProbably = 0;
+ // 划分是否是大于一边的划分
+ boolean isLarger = false;
+ String[] array;
+ ArrayList pList = new ArrayList<>();
+
+ array = classification.split("=");
+ position = array[0];
+ value = Double.parseDouble(array[1]);
+
+ if (position.equals("X")) {
+ if (p.getX() > value) {
+ isLarger = true;
+ }
+
+ // 将训练数据中所有属于这边的点加入
+ for (Point point : totalPoint) {
+ if (isLarger && point.getX() > value) {
+ pList.add(point);
+ } else if (!isLarger && point.getX() < value) {
+ pList.add(point);
+ }
+ }
+ } else if (position.equals("Y")) {
+ if (p.getY() > value) {
+ isLarger = true;
+ }
+
+ // 将训练数据中所有属于这边的点加入
+ for (Point point : totalPoint) {
+ if (isLarger && point.getY() > value) {
+ pList.add(point);
+ } else if (!isLarger && point.getY() < value) {
+ pList.add(point);
+ }
+ }
+ }
+
+ for (Point p2 : pList) {
+ if (p2.getClassType() == CLASS_POSITIVE) {
+ posProbably++;
+ } else {
+ negProbably++;
+ }
+ }
+
+ //分类按正负类数量进行划分
+ if (posProbably > negProbably) {
+ return CLASS_POSITIVE;
+ } else {
+ return CLASS_NEGTIVE;
+ }
+ }
}
diff --git a/BaggingAndBoosting/DataMining_AdaBoost/Client.java b/BaggingAndBoosting/DataMining_AdaBoost/Client.java
index 30e98ec..f00bcdc 100644
--- a/BaggingAndBoosting/DataMining_AdaBoost/Client.java
+++ b/BaggingAndBoosting/DataMining_AdaBoost/Client.java
@@ -1,13 +1,13 @@
-package DataMining_AdaBoost;
+package BaggingAndBoosting.DataMining_AdaBoost;
/**
* AdaBoost提升算法调用类
- * @author lyq
+ * @author Qstar
*
*/
public class Client {
public static void main(String[] agrs){
- String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
+ String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/BaggingAndBoosting/DataMining_AdaBoost/input.txt";
//误差率阈值
double errorValue = 0.2;
diff --git a/BaggingAndBoosting/DataMining_AdaBoost/Point.java b/BaggingAndBoosting/DataMining_AdaBoost/Point.java
index 7d352ab..f2fdac5 100644
--- a/BaggingAndBoosting/DataMining_AdaBoost/Point.java
+++ b/BaggingAndBoosting/DataMining_AdaBoost/Point.java
@@ -1,62 +1,57 @@
-package DataMining_AdaBoost;
+package BaggingAndBoosting.DataMining_AdaBoost;
/**
* 坐标点类
- *
- * @author lyq
- *
+ *
+ * @author Qstar
*/
public class Point {
- // 坐标点x坐标
- private int x;
- // 坐标点y坐标
- private int y;
- // 坐标点的分类类别
- private int classType;
- //如果此节点被划错,他的误差率,不能用个数除以总数,因为不同坐标点的权重不一定相等
- private double probably;
-
- public Point(int x, int y, int classType){
- this.x = x;
- this.y = y;
- this.classType = classType;
- }
-
- public Point(String x, String y, String classType){
- this.x = Integer.parseInt(x);
- this.y = Integer.parseInt(y);
- this.classType = Integer.parseInt(classType);
- }
-
- public int getX() {
- return x;
- }
-
- public void setX(int x) {
- this.x = x;
- }
-
- public int getY() {
- return y;
- }
-
- public void setY(int y) {
- this.y = y;
- }
-
- public int getClassType() {
- return classType;
- }
-
- public void setClassType(int classType) {
- this.classType = classType;
- }
-
- public double getProbably() {
- return probably;
- }
-
- public void setProbably(double probably) {
- this.probably = probably;
- }
+ // 坐标点x坐标
+ private int x;
+ // 坐标点y坐标
+ private int y;
+ // 坐标点的分类类别
+ private int classType;
+    //如果此节点被划错,它的误差率,不能用个数除以总数,因为不同坐标点的权重不一定相等
+ private double probably;
+
+ public Point(int x, int y, int classType){
+ this.x = x;
+ this.y = y;
+ this.classType = classType;
+ }
+
+ public Point(String x, String y, String classType){
+ this.x = Integer.parseInt(x);
+ this.y = Integer.parseInt(y);
+ this.classType = Integer.parseInt(classType);
+ }
+
+ public int getX(){
+ return x;
+ }
+
+ public void setX(int x){
+ this.x = x;
+ }
+
+ public int getY(){
+ return y;
+ }
+
+ public void setY(int y){
+ this.y = y;
+ }
+
+ int getClassType(){
+ return classType;
+ }
+
+ double getProbably(){
+ return probably;
+ }
+
+ void setProbably(double probably){
+ this.probably = probably;
+ }
}
diff --git a/Classification/DataMining_CART/AttrNode.java b/Classification/DataMining_CART/AttrNode.java
index c259ff1..5e60f9d 100644
--- a/Classification/DataMining_CART/AttrNode.java
+++ b/Classification/DataMining_CART/AttrNode.java
@@ -1,85 +1,83 @@
-package DataMining_CART;
+package Classification.DataMining_CART;
import java.util.ArrayList;
/**
* 回归分类树节点
- *
- * @author lyq
- *
+ *
+ * @author Qstar
*/
-public class AttrNode {
- // 节点属性名字
- private String attrName;
- // 节点索引标号
- private int nodeIndex;
- //包含的叶子节点数
- private int leafNum;
- // 节点误差率
- private double alpha;
- // 父亲分类属性值
- private String parentAttrValue;
- // 孩子节点
- private AttrNode[] childAttrNode;
- // 数据记录索引
- private ArrayList dataIndex;
-
- public String getAttrName() {
- return attrName;
- }
-
- public void setAttrName(String attrName) {
- this.attrName = attrName;
- }
-
- public int getNodeIndex() {
- return nodeIndex;
- }
-
- public void setNodeIndex(int nodeIndex) {
- this.nodeIndex = nodeIndex;
- }
-
- public double getAlpha() {
- return alpha;
- }
-
- public void setAlpha(double alpha) {
- this.alpha = alpha;
- }
-
- public String getParentAttrValue() {
- return parentAttrValue;
- }
-
- public void setParentAttrValue(String parentAttrValue) {
- this.parentAttrValue = parentAttrValue;
- }
-
- public AttrNode[] getChildAttrNode() {
- return childAttrNode;
- }
-
- public void setChildAttrNode(AttrNode[] childAttrNode) {
- this.childAttrNode = childAttrNode;
- }
-
- public ArrayList getDataIndex() {
- return dataIndex;
- }
-
- public void setDataIndex(ArrayList dataIndex) {
- this.dataIndex = dataIndex;
- }
-
- public int getLeafNum() {
- return leafNum;
- }
-
- public void setLeafNum(int leafNum) {
- this.leafNum = leafNum;
- }
-
-
-
+class AttrNode {
+ // 节点属性名字
+ private String attrName;
+ // 节点索引标号
+ private int nodeIndex;
+ //包含的叶子节点数
+ private int leafNum;
+ // 节点误差率
+ private double alpha;
+ // 父亲分类属性值
+ private String parentAttrValue;
+ // 孩子节点
+ private AttrNode[] childAttrNode;
+ // 数据记录索引
+ private ArrayList dataIndex;
+
+ public String getAttrName(){
+ return attrName;
+ }
+
+ public void setAttrName(String attrName){
+ this.attrName = attrName;
+ }
+
+ int getNodeIndex(){
+ return nodeIndex;
+ }
+
+ void setNodeIndex(int nodeIndex){
+ this.nodeIndex = nodeIndex;
+ }
+
+ double getAlpha(){
+ return alpha;
+ }
+
+ void setAlpha(double alpha){
+ this.alpha = alpha;
+ }
+
+ String getParentAttrValue(){
+ return parentAttrValue;
+ }
+
+ void setParentAttrValue(String parentAttrValue){
+ this.parentAttrValue = parentAttrValue;
+ }
+
+ AttrNode[] getChildAttrNode(){
+ return childAttrNode;
+ }
+
+ void setChildAttrNode(AttrNode[] childAttrNode){
+ this.childAttrNode = childAttrNode;
+ }
+
+ ArrayList getDataIndex(){
+ return dataIndex;
+ }
+
+ void setDataIndex(ArrayList dataIndex){
+ this.dataIndex = dataIndex;
+ }
+
+ int getLeafNum(){
+ return leafNum;
+ }
+
+ void setLeafNum(int leafNum){
+ this.leafNum = leafNum;
+ }
+
+
}
diff --git a/Classification/DataMining_CART/CARTTool.java b/Classification/DataMining_CART/CARTTool.java
index 2ac9d7a..3ecf0f6 100644
--- a/Classification/DataMining_CART/CARTTool.java
+++ b/Classification/DataMining_CART/CARTTool.java
@@ -1,546 +1,514 @@
-package DataMining_CART;
+package Classification.DataMining_CART;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.LinkedList;
-import java.util.Map;
-import java.util.Queue;
-
-import javax.lang.model.element.NestingKind;
-import javax.swing.text.DefaultEditorKit.CutAction;
-import javax.swing.text.html.MinimalHTMLWriter;
+import java.util.*;
/**
* CART分类回归树算法工具类
- *
- * @author lyq
- *
+ *
+ * @author Qstar
*/
-public class CARTTool {
- // 类标号的值类型
- private final String YES = "Yes";
- private final String NO = "No";
-
- // 所有属性的类型总数,在这里就是data源数据的列数
- private int attrNum;
- private String filePath;
- // 初始源数据,用一个二维字符数组存放模仿表格数据
- private String[][] data;
- // 数据的属性行的名字
- private String[] attrNames;
- // 每个属性的值所有类型
- private HashMap> attrValue;
-
- public CARTTool(String filePath) {
- this.filePath = filePath;
- attrValue = new HashMap<>();
- }
-
- /**
- * 从文件中读取数据
- */
- public void readDataFile() {
- File file = new File(filePath);
- ArrayList dataArray = new ArrayList();
-
- try {
- BufferedReader in = new BufferedReader(new FileReader(file));
- String str;
- String[] tempArray;
- while ((str = in.readLine()) != null) {
- tempArray = str.split(" ");
- dataArray.add(tempArray);
- }
- in.close();
- } catch (IOException e) {
- e.getStackTrace();
- }
-
- data = new String[dataArray.size()][];
- dataArray.toArray(data);
- attrNum = data[0].length;
- attrNames = data[0];
+class CARTTool {
+ // 类标号的值类型
+ private final String YES = "Yes";
+ private final String NO = "No";
+
+ // 所有属性的类型总数,在这里就是data源数据的列数
+ private int attrNum;
+ private String filePath;
+ // 初始源数据,用一个二维字符数组存放模仿表格数据
+ private String[][] data;
+ // 数据的属性行的名字
+ private String[] attrNames;
+ // 每个属性的值所有类型
+    private HashMap<String, ArrayList<String>> attrValue;
+
+ CARTTool(String filePath){
+ this.filePath = filePath;
+ attrValue = new HashMap<>();
+ }
+
+ /**
+ * 从文件中读取数据
+ */
+ public void readDataFile(){
+ File file = new File(filePath);
+        ArrayList<String[]> dataArray = new ArrayList<>();
+
+ try {
+ BufferedReader in = new BufferedReader(new FileReader(file));
+ String str;
+ String[] tempArray;
+ while ((str = in.readLine()) != null) {
+ tempArray = str.split(" ");
+ dataArray.add(tempArray);
+ }
+ in.close();
+ } catch (IOException e) {
+ e.getStackTrace();
+ }
+
+ data = new String[dataArray.size()][];
+ dataArray.toArray(data);
+ attrNum = data[0].length;
+ attrNames = data[0];
/*
- * for (int i = 0; i < data.length; i++) { for (int j = 0; j <
+ * for (int i = 0; i < data.length; i++) { for (int j = 0; j <
* data[0].length; j++) { System.out.print(" " + data[i][j]); }
* System.out.print("\n"); }
*/
- }
-
- /**
- * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用
- */
- public void initAttrValue() {
- ArrayList tempValues;
-
- // 按照列的方式,从左往右找
- for (int j = 1; j < attrNum; j++) {
- // 从一列中的上往下开始寻找值
- tempValues = new ArrayList<>();
- for (int i = 1; i < data.length; i++) {
- if (!tempValues.contains(data[i][j])) {
- // 如果这个属性的值没有添加过,则添加
- tempValues.add(data[i][j]);
- }
- }
-
- // 一列属性的值已经遍历完毕,复制到map属性表中
- attrValue.put(data[0][j], tempValues);
- }
+ }
+
+ /**
+ * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用
+ */
+ private void initAttrValue(){
+        ArrayList<String> tempValues;
+
+ // 按照列的方式,从左往右找
+ for (int j = 1; j < attrNum; j++) {
+ // 从一列中的上往下开始寻找值
+ tempValues = new ArrayList<>();
+ for (int i = 1; i < data.length; i++) {
+ if (!tempValues.contains(data[i][j])) {
+ // 如果这个属性的值没有添加过,则添加
+ tempValues.add(data[i][j]);
+ }
+ }
+
+ // 一列属性的值已经遍历完毕,复制到map属性表中
+ attrValue.put(data[0][j], tempValues);
+ }
/*
- * for (Map.Entry entry : attrValue.entrySet()) {
+ * for (Map.Entry entry : attrValue.entrySet()) {
* System.out.println("key:value " + entry.getKey() + ":" +
* entry.getValue()); }
*/
- }
-
- /**
- * 计算机基尼指数
- *
- * @param remainData
- * 剩余数据
- * @param attrName
- * 属性名称
- * @param value
- * 属性值
- * @param beLongValue
- * 分类是否属于此属性值
- * @return
- */
- public double computeGini(String[][] remainData, String attrName,
- String value, boolean beLongValue) {
- // 实例总数
- int total = 0;
- // 正实例数
- int posNum = 0;
- // 负实例数
- int negNum = 0;
- // 基尼指数
- double gini = 0;
-
- // 还是按列从左往右遍历属性
- for (int j = 1; j < attrNames.length; j++) {
- // 找到了指定的属性
- if (attrName.equals(attrNames[j])) {
- for (int i = 1; i < remainData.length; i++) {
- // 统计正负实例按照属于和不属于值类型进行划分
- if ((beLongValue && remainData[i][j].equals(value))
- || (!beLongValue && !remainData[i][j].equals(value))) {
- if (remainData[i][attrNames.length - 1].equals(YES)) {
- // 判断此行数据是否为正实例
- posNum++;
- } else {
- negNum++;
- }
- }
- }
- }
- }
-
- total = posNum + negNum;
- double posProbobly = (double) posNum / total;
- double negProbobly = (double) negNum / total;
- gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly;
-
- // 返回计算基尼指数
- return gini;
- }
-
- /**
- * 计算属性划分的最小基尼指数,返回最小的属性值划分和最小的基尼指数,保存在一个数组中
- *
- * @param remainData
- * 剩余谁
- * @param attrName
- * 属性名称
- * @return
- */
- public String[] computeAttrGini(String[][] remainData, String attrName) {
- String[] str = new String[2];
- // 最终该属性的划分类型值
- String spiltValue = "";
- // 临时变量
- int tempNum = 0;
- // 保存属性的值划分时的最小的基尼指数
- double minGini = Integer.MAX_VALUE;
- ArrayList valueTypes = attrValue.get(attrName);
- // 属于此属性值的实例数
- HashMap belongNum = new HashMap<>();
-
- for (String string : valueTypes) {
- // 重新计数的时候,数字归0
- tempNum = 0;
- // 按列从左往右遍历属性
- for (int j = 1; j < attrNames.length; j++) {
- // 找到了指定的属性
- if (attrName.equals(attrNames[j])) {
- for (int i = 1; i < remainData.length; i++) {
- // 统计正负实例按照属于和不属于值类型进行划分
- if (remainData[i][j].equals(string)) {
- tempNum++;
- }
- }
- }
- }
-
- belongNum.put(string, tempNum);
- }
-
- double tempGini = 0;
- double posProbably = 1.0;
- double negProbably = 1.0;
- for (String string : valueTypes) {
- tempGini = 0;
-
- posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1);
- negProbably = 1 - posProbably;
-
- tempGini += posProbably
- * computeGini(remainData, attrName, string, true);
- tempGini += negProbably
- * computeGini(remainData, attrName, string, false);
-
- if (tempGini < minGini) {
- minGini = tempGini;
- spiltValue = string;
- }
- }
-
- str[0] = spiltValue;
- str[1] = minGini + "";
-
- return str;
- }
-
- public void buildDecisionTree(AttrNode node, String parentAttrValue,
- String[][] remainData, ArrayList remainAttr,
- boolean beLongParentValue) {
- // 属性划分值
- String valueType = "";
- // 划分属性名称
- String spiltAttrName = "";
- double minGini = Integer.MAX_VALUE;
- double tempGini = 0;
- // 基尼指数数组,保存了基尼指数和此基尼指数的划分属性值
- String[] giniArray;
-
- if (beLongParentValue) {
- node.setParentAttrValue(parentAttrValue);
- } else {
- node.setParentAttrValue("!" + parentAttrValue);
- }
-
- if (remainAttr.size() == 0) {
- if (remainData.length > 1) {
- ArrayList indexArray = new ArrayList<>();
- for (int i = 1; i < remainData.length; i++) {
- indexArray.add(remainData[i][0]);
- }
- node.setDataIndex(indexArray);
- }
- System.out.println("attr remain null");
- return;
- }
-
- for (String str : remainAttr) {
- giniArray = computeAttrGini(remainData, str);
- tempGini = Double.parseDouble(giniArray[1]);
-
- if (tempGini < minGini) {
- spiltAttrName = str;
- minGini = tempGini;
- valueType = giniArray[0];
- }
- }
- // 移除划分属性
- remainAttr.remove(spiltAttrName);
- node.setAttrName(spiltAttrName);
-
- // 孩子节点,分类回归树中,每次二元划分,分出2个孩子节点
- AttrNode[] childNode = new AttrNode[2];
- String[][] rData;
-
- boolean[] bArray = new boolean[] { true, false };
- for (int i = 0; i < bArray.length; i++) {
- // 二元划分属于属性值的划分
- rData = removeData(remainData, spiltAttrName, valueType, bArray[i]);
-
- boolean sameClass = true;
- ArrayList indexArray = new ArrayList<>();
- for (int k = 1; k < rData.length; k++) {
- indexArray.add(rData[k][0]);
- // 判断是否为同一类的
- if (!rData[k][attrNames.length - 1]
- .equals(rData[1][attrNames.length - 1])) {
- // 只要有1个不相等,就不是同类型的
- sameClass = false;
- break;
- }
- }
-
- childNode[i] = new AttrNode();
- if (!sameClass) {
- // 创建新的对象属性,对象的同个引用会出错
- ArrayList rAttr = new ArrayList<>();
- for (String str : remainAttr) {
- rAttr.add(str);
- }
- buildDecisionTree(childNode[i], valueType, rData, rAttr,
- bArray[i]);
- } else {
- String pAtr = (bArray[i] ? valueType : "!" + valueType);
- childNode[i].setParentAttrValue(pAtr);
- childNode[i].setDataIndex(indexArray);
- }
- }
-
- node.setChildAttrNode(childNode);
- }
-
- /**
- * 属性划分完毕,进行数据的移除
- *
- * @param srcData
- * 源数据
- * @param attrName
- * 划分的属性名称
- * @param valueType
- * 属性的值类型
- * @parame beLongValue 分类是否属于此值类型
- */
- private String[][] removeData(String[][] srcData, String attrName,
- String valueType, boolean beLongValue) {
- String[][] desDataArray;
- ArrayList desData = new ArrayList<>();
- // 待删除数据
- ArrayList selectData = new ArrayList<>();
- selectData.add(attrNames);
-
- // 数组数据转化到列表中,方便移除
- for (int i = 0; i < srcData.length; i++) {
- desData.add(srcData[i]);
- }
-
- // 还是从左往右一列列的查找
- for (int j = 1; j < attrNames.length; j++) {
- if (attrNames[j].equals(attrName)) {
- for (int i = 1; i < desData.size(); i++) {
- if (desData.get(i)[j].equals(valueType)) {
- // 如果匹配这个数据,则移除其他的数据
- selectData.add(desData.get(i));
- }
- }
- }
- }
-
- if (beLongValue) {
- desDataArray = new String[selectData.size()][];
- selectData.toArray(desDataArray);
- } else {
- // 属性名称行不移除
- selectData.remove(attrNames);
- // 如果是划分不属于此类型的数据时,进行移除
- desData.removeAll(selectData);
- desDataArray = new String[desData.size()][];
- desData.toArray(desDataArray);
- }
-
- return desDataArray;
- }
-
- public void startBuildingTree() {
- readDataFile();
- initAttrValue();
-
- ArrayList remainAttr = new ArrayList<>();
- // 添加属性,除了最后一个类标号属性
- for (int i = 1; i < attrNames.length - 1; i++) {
- remainAttr.add(attrNames[i]);
- }
-
- AttrNode rootNode = new AttrNode();
- buildDecisionTree(rootNode, "", data, remainAttr, false);
- setIndexAndAlpah(rootNode, 0, false);
- System.out.println("剪枝前:");
- showDecisionTree(rootNode, 1);
- setIndexAndAlpah(rootNode, 0, true);
- System.out.println("\n剪枝后:");
- showDecisionTree(rootNode, 1);
- }
-
- /**
- * 显示决策树
- *
- * @param node
- * 待显示的节点
- * @param blankNum
- * 行空格符,用于显示树型结构
- */
- private void showDecisionTree(AttrNode node, int blankNum) {
- System.out.println();
- for (int i = 0; i < blankNum; i++) {
- System.out.print(" ");
- }
- System.out.print("--");
- // 显示分类的属性值
- if (node.getParentAttrValue() != null
- && node.getParentAttrValue().length() > 0) {
- System.out.print(node.getParentAttrValue());
- } else {
- System.out.print("--");
- }
- System.out.print("--");
-
- if (node.getDataIndex() != null && node.getDataIndex().size() > 0) {
- String i = node.getDataIndex().get(0);
- System.out.print("【" + node.getNodeIndex() + "】类别:"
- + data[Integer.parseInt(i)][attrNames.length - 1]);
- System.out.print("[");
- for (String index : node.getDataIndex()) {
- System.out.print(index + ", ");
- }
- System.out.print("]");
- } else {
- // 递归显示子节点
- System.out.print("【" + node.getNodeIndex() + ":"
- + node.getAttrName() + "】");
- if (node.getChildAttrNode() != null) {
- for (AttrNode childNode : node.getChildAttrNode()) {
- showDecisionTree(childNode, 2 * blankNum);
- }
- } else {
- System.out.print("【 Child Null】");
- }
- }
- }
-
- /**
- * 为节点设置序列号,并计算每个节点的误差率,用于后面剪枝
- *
- * @param node
- * 开始的时候传入的是根节点
- * @param index
- * 开始的索引号,从1开始
- * @param ifCutNode
- * 是否需要剪枝
- */
- private void setIndexAndAlpah(AttrNode node, int index, boolean ifCutNode) {
- AttrNode tempNode;
- // 最小误差代价节点,即将被剪枝的节点
- AttrNode minAlphaNode = null;
- double minAlpah = Integer.MAX_VALUE;
- Queue nodeQueue = new LinkedList();
-
- nodeQueue.add(node);
- while (nodeQueue.size() > 0) {
- index++;
- // 从队列头部获取首个节点
- tempNode = nodeQueue.poll();
- tempNode.setNodeIndex(index);
- if (tempNode.getChildAttrNode() != null) {
- for (AttrNode childNode : tempNode.getChildAttrNode()) {
- nodeQueue.add(childNode);
- }
- computeAlpha(tempNode);
- if (tempNode.getAlpha() < minAlpah) {
- minAlphaNode = tempNode;
- minAlpah = tempNode.getAlpha();
- } else if (tempNode.getAlpha() == minAlpah) {
- // 如果误差代价值一样,比较包含的叶子节点个数,剪枝有多叶子节点数的节点
- if (tempNode.getLeafNum() > minAlphaNode.getLeafNum()) {
- minAlphaNode = tempNode;
- }
- }
- }
- }
-
- if (ifCutNode) {
- // 进行树的剪枝,让其左右孩子节点为null
- minAlphaNode.setChildAttrNode(null);
- }
- }
-
- /**
- * 为非叶子节点计算误差代价,这里的后剪枝法用的是CCP代价复杂度剪枝
- *
- * @param node
- * 待计算的非叶子节点
- */
- private void computeAlpha(AttrNode node) {
- double rt = 0;
- double Rt = 0;
- double alpha = 0;
- // 当前节点的数据总数
- int sumNum = 0;
- // 最少的偏差数
- int minNum = 0;
-
- ArrayList dataIndex;
- ArrayList leafNodes = new ArrayList<>();
-
- addLeafNode(node, leafNodes);
- node.setLeafNum(leafNodes.size());
- for (AttrNode attrNode : leafNodes) {
- dataIndex = attrNode.getDataIndex();
-
- int num = 0;
- sumNum += dataIndex.size();
- for (String s : dataIndex) {
- // 统计分类数据中的正负实例数
- if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) {
- num++;
- }
- }
- minNum += num;
-
- // 取小数量的值部分
- if (1.0 * num / dataIndex.size() > 0.5) {
- num = dataIndex.size() - num;
- }
-
- rt += (1.0 * num / (data.length - 1));
- }
-
- //同样取出少偏差的那部分
- if (1.0 * minNum / sumNum > 0.5) {
- minNum = sumNum - minNum;
- }
-
- Rt = 1.0 * minNum / (data.length - 1);
- alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1);
- node.setAlpha(alpha);
- }
-
- /**
- * 筛选出节点所包含的叶子节点数
- *
- * @param node
- * 待筛选节点
- * @param leafNode
- * 叶子节点列表容器
- */
- private void addLeafNode(AttrNode node, ArrayList leafNode) {
- ArrayList dataIndex;
-
- if (node.getChildAttrNode() != null) {
- for (AttrNode childNode : node.getChildAttrNode()) {
- dataIndex = childNode.getDataIndex();
- if (dataIndex != null && dataIndex.size() > 0) {
- // 说明此节点为叶子节点
- leafNode.add(childNode);
- } else {
- // 如果还是非叶子节点则继续递归调用
- addLeafNode(childNode, leafNode);
- }
- }
- }
- }
+ }
+
+ /**
+     * 计算基尼指数
+ *
+ * @param remainData 剩余数据
+ * @param attrName 属性名称
+ * @param value 属性值
+ * @param beLongValue 分类是否属于此属性值
+ */
+ private double computeGini(String[][] remainData, String attrName,
+ String value, boolean beLongValue){
+ // 实例总数
+ int total;
+ // 正实例数
+ int posNum = 0;
+ // 负实例数
+ int negNum = 0;
+ // 基尼指数
+ double gini;
+
+ // 还是按列从左往右遍历属性
+ for (int j = 1; j < attrNames.length; j++) {
+ // 找到了指定的属性
+ if (attrName.equals(attrNames[j])) {
+ for (int i = 1; i < remainData.length; i++) {
+ // 统计正负实例按照属于和不属于值类型进行划分
+ if ((beLongValue && remainData[i][j].equals(value))
+ || (!beLongValue && !remainData[i][j].equals(value))) {
+ if (remainData[i][attrNames.length - 1].equals(YES)) {
+ // 判断此行数据是否为正实例
+ posNum++;
+ } else {
+ negNum++;
+ }
+ }
+ }
+ }
+ }
+
+ total = posNum + negNum;
+ double posProbobly = (double) posNum / total;
+ double negProbobly = (double) negNum / total;
+ gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly;
+
+ // 返回计算基尼指数
+ return gini;
+ }
+
+ /**
+ * 计算属性划分的最小基尼指数,返回最小的属性值划分和最小的基尼指数,保存在一个数组中
+ *
+     * @param remainData 剩余数据
+ * @param attrName 属性名称
+ */
+ private String[] computeAttrGini(String[][] remainData, String attrName){
+ String[] str = new String[2];
+ // 最终该属性的划分类型值
+ String spiltValue = "";
+ // 临时变量
+ int tempNum;
+ // 保存属性的值划分时的最小的基尼指数
+ double minGini = Integer.MAX_VALUE;
+        ArrayList<String> valueTypes = attrValue.get(attrName);
+ // 属于此属性值的实例数
+        HashMap<String, Integer> belongNum = new HashMap<>();
+
+ for (String string : valueTypes) {
+ // 重新计数的时候,数字归0
+ tempNum = 0;
+ // 按列从左往右遍历属性
+ for (int j = 1; j < attrNames.length; j++) {
+ // 找到了指定的属性
+ if (attrName.equals(attrNames[j])) {
+ for (int i = 1; i < remainData.length; i++) {
+ // 统计正负实例按照属于和不属于值类型进行划分
+ if (remainData[i][j].equals(string)) {
+ tempNum++;
+ }
+ }
+ }
+ }
+
+ belongNum.put(string, tempNum);
+ }
+
+ double tempGini;
+ double posProbably;
+ double negProbably;
+ for (String string : valueTypes) {
+ tempGini = 0;
+
+ posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1);
+ negProbably = 1 - posProbably;
+
+ tempGini += posProbably
+ * computeGini(remainData, attrName, string, true);
+ tempGini += negProbably
+ * computeGini(remainData, attrName, string, false);
+
+ if (tempGini < minGini) {
+ minGini = tempGini;
+ spiltValue = string;
+ }
+ }
+
+ str[0] = spiltValue;
+ str[1] = minGini + "";
+
+ return str;
+ }
+
+ private void buildDecisionTree(AttrNode node, String parentAttrValue,
+                                   String[][] remainData, ArrayList<String> remainAttr,
+ boolean beLongParentValue){
+ // 属性划分值
+ String valueType = "";
+ // 划分属性名称
+ String spiltAttrName = "";
+ double minGini = Integer.MAX_VALUE;
+ double tempGini;
+ // 基尼指数数组,保存了基尼指数和此基尼指数的划分属性值
+ String[] giniArray;
+
+ if (beLongParentValue) {
+ node.setParentAttrValue(parentAttrValue);
+ } else {
+ node.setParentAttrValue("!" + parentAttrValue);
+ }
+
+ if (remainAttr.size() == 0) {
+ if (remainData.length > 1) {
+                ArrayList<String> indexArray = new ArrayList<>();
+ for (int i = 1; i < remainData.length; i++) {
+ indexArray.add(remainData[i][0]);
+ }
+ node.setDataIndex(indexArray);
+ }
+ System.out.println("attr remain null");
+ return;
+ }
+
+ for (String str : remainAttr) {
+ giniArray = computeAttrGini(remainData, str);
+ tempGini = Double.parseDouble(giniArray[1]);
+
+ if (tempGini < minGini) {
+ spiltAttrName = str;
+ minGini = tempGini;
+ valueType = giniArray[0];
+ }
+ }
+ // 移除划分属性
+ remainAttr.remove(spiltAttrName);
+ node.setAttrName(spiltAttrName);
+
+ // 孩子节点,分类回归树中,每次二元划分,分出2个孩子节点
+ AttrNode[] childNode = new AttrNode[2];
+ String[][] rData;
+
+ boolean[] bArray = new boolean[]{true, false};
+ for (int i = 0; i < bArray.length; i++) {
+ // 二元划分属于属性值的划分
+ rData = removeData(remainData, spiltAttrName, valueType, bArray[i]);
+
+ boolean sameClass = true;
+            ArrayList<String> indexArray = new ArrayList<>();
+ for (int k = 1; k < rData.length; k++) {
+ indexArray.add(rData[k][0]);
+ // 判断是否为同一类的
+ if (!rData[k][attrNames.length - 1]
+ .equals(rData[1][attrNames.length - 1])) {
+ // 只要有1个不相等,就不是同类型的
+ sameClass = false;
+ break;
+ }
+ }
+
+ childNode[i] = new AttrNode();
+ if (!sameClass) {
+ // 创建新的对象属性,对象的同个引用会出错
+                ArrayList<String> rAttr = new ArrayList<>();
+ for (String str : remainAttr) {
+ rAttr.add(str);
+ }
+ buildDecisionTree(childNode[i], valueType, rData, rAttr,
+ bArray[i]);
+ } else {
+ String pAtr = (bArray[i] ? valueType : "!" + valueType);
+ childNode[i].setParentAttrValue(pAtr);
+ childNode[i].setDataIndex(indexArray);
+ }
+ }
+
+ node.setChildAttrNode(childNode);
+ }
+
+ /**
+ * 属性划分完毕,进行数据的移除
+ *
+ * @param srcData 源数据
+ * @param attrName 划分的属性名称
+ * @param valueType 属性的值类型
+ * @param beLongValue 分类是否属于此值类型
+ */
+ private String[][] removeData(String[][] srcData, String attrName,
+ String valueType, boolean beLongValue){
+ String[][] desDataArray;
+        ArrayList<String[]> desData = new ArrayList<>();
+        // 待删除数据
+        ArrayList<String[]> selectData = new ArrayList<>();
+ selectData.add(attrNames);
+
+ // 数组数据转化到列表中,方便移除
+ Collections.addAll(desData, srcData);
+
+ // 还是从左往右一列列的查找
+ for (int j = 1; j < attrNames.length; j++) {
+ if (attrNames[j].equals(attrName)) {
+ for (int i = 1; i < desData.size(); i++) {
+ if (desData.get(i)[j].equals(valueType)) {
+ // 如果匹配这个数据,则移除其他的数据
+ selectData.add(desData.get(i));
+ }
+ }
+ }
+ }
+
+ if (beLongValue) {
+ desDataArray = new String[selectData.size()][];
+ selectData.toArray(desDataArray);
+ } else {
+ // 属性名称行不移除
+ selectData.remove(attrNames);
+ // 如果是划分不属于此类型的数据时,进行移除
+ desData.removeAll(selectData);
+ desDataArray = new String[desData.size()][];
+ desData.toArray(desDataArray);
+ }
+
+ return desDataArray;
+ }
+
+ void startBuildingTree(){
+ readDataFile();
+ initAttrValue();
+
+        ArrayList<String> remainAttr = new ArrayList<>();
+ // 添加属性,除了最后一个类标号属性
+ remainAttr.addAll(Arrays.asList(attrNames).subList(1, attrNames.length - 1));
+
+ AttrNode rootNode = new AttrNode();
+ buildDecisionTree(rootNode, "", data, remainAttr, false);
+ setIndexAndAlpah(rootNode, 0, false);
+ System.out.println("剪枝前:");
+ showDecisionTree(rootNode, 1);
+ setIndexAndAlpah(rootNode, 0, true);
+ System.out.println("\n剪枝后:");
+ showDecisionTree(rootNode, 1);
+ }
+
+ /**
+ * 显示决策树
+ *
+ * @param node 待显示的节点
+ * @param blankNum 行空格符,用于显示树型结构
+ */
+ private void showDecisionTree(AttrNode node, int blankNum){
+ System.out.println();
+ for (int i = 0; i < blankNum; i++) {
+ System.out.print(" ");
+ }
+ System.out.print("--");
+ // 显示分类的属性值
+ if (node.getParentAttrValue() != null
+ && node.getParentAttrValue().length() > 0) {
+ System.out.print(node.getParentAttrValue());
+ } else {
+ System.out.print("--");
+ }
+ System.out.print("--");
+
+ if (node.getDataIndex() != null && node.getDataIndex().size() > 0) {
+ String i = node.getDataIndex().get(0);
+ System.out.print("【" + node.getNodeIndex() + "】类别:"
+ + data[Integer.parseInt(i)][attrNames.length - 1]);
+ System.out.print("[");
+ for (String index : node.getDataIndex()) {
+ System.out.print(index + ", ");
+ }
+ System.out.print("]");
+ } else {
+ // 递归显示子节点
+ System.out.print("【" + node.getNodeIndex() + ":"
+ + node.getAttrName() + "】");
+ if (node.getChildAttrNode() != null) {
+ for (AttrNode childNode : node.getChildAttrNode()) {
+ showDecisionTree(childNode, 2 * blankNum);
+ }
+ } else {
+ System.out.print("【 Child Null】");
+ }
+ }
+ }
+
+ /**
+ * 为节点设置序列号,并计算每个节点的误差率,用于后面剪枝
+ *
+ * @param node 开始的时候传入的是根节点
+ * @param index 开始的索引号,从1开始
+ * @param ifCutNode 是否需要剪枝
+ */
+ private void setIndexAndAlpah(AttrNode node, int index, boolean ifCutNode){
+ AttrNode tempNode;
+ // 最小误差代价节点,即将被剪枝的节点
+ AttrNode minAlphaNode = null;
+ double minAlpah = Integer.MAX_VALUE;
+        Queue<AttrNode> nodeQueue = new LinkedList<>();
+
+ nodeQueue.add(node);
+ while (nodeQueue.size() > 0) {
+ index++;
+ // 从队列头部获取首个节点
+ tempNode = nodeQueue.poll();
+ tempNode.setNodeIndex(index);
+ if (tempNode.getChildAttrNode() != null) {
+ Collections.addAll(nodeQueue, tempNode.getChildAttrNode());
+ computeAlpha(tempNode);
+ if (tempNode.getAlpha() < minAlpah) {
+ minAlphaNode = tempNode;
+ minAlpah = tempNode.getAlpha();
+ } else if (tempNode.getAlpha() == minAlpah) {
+ // 如果误差代价值一样,比较包含的叶子节点个数,剪枝有多叶子节点数的节点
+ if (minAlphaNode != null && tempNode.getLeafNum() > minAlphaNode.getLeafNum()) {
+ minAlphaNode = tempNode;
+ }
+ }
+ }
+ }
+
+ if (ifCutNode) {
+ // 进行树的剪枝,让其左右孩子节点为null
+ if (minAlphaNode != null) {
+ minAlphaNode.setChildAttrNode(null);
+ }
+ }
+ }
+
+ /**
+ * 为非叶子节点计算误差代价,这里的后剪枝法用的是CCP代价复杂度剪枝
+ *
+ * @param node 待计算的非叶子节点
+ */
+ private void computeAlpha(AttrNode node){
+ double rt = 0;
+ double Rt;
+ double alpha;
+ // 当前节点的数据总数
+ int sumNum = 0;
+ // 最少的偏差数
+ int minNum = 0;
+
+        ArrayList<String> dataIndex;
+        ArrayList<AttrNode> leafNodes = new ArrayList<>();
+
+ addLeafNode(node, leafNodes);
+ node.setLeafNum(leafNodes.size());
+ for (AttrNode attrNode : leafNodes) {
+ dataIndex = attrNode.getDataIndex();
+
+ int num = 0;
+ sumNum += dataIndex.size();
+ for (String s : dataIndex) {
+ // 统计分类数据中的正负实例数
+ if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) {
+ num++;
+ }
+ }
+ minNum += num;
+
+ // 取小数量的值部分
+ if (1.0 * num / dataIndex.size() > 0.5) {
+ num = dataIndex.size() - num;
+ }
+
+ rt += (1.0 * num / (data.length - 1));
+ }
+
+ //同样取出少偏差的那部分
+ if (1.0 * minNum / sumNum > 0.5) {
+ minNum = sumNum - minNum;
+ }
+
+ Rt = 1.0 * minNum / (data.length - 1);
+ alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1);
+ node.setAlpha(alpha);
+ }
+
+ /**
+ * 筛选出节点所包含的叶子节点数
+ *
+ * @param node 待筛选节点
+ * @param leafNode 叶子节点列表容器
+ */
+    private void addLeafNode(AttrNode node, ArrayList<AttrNode> leafNode){
+        ArrayList<String> dataIndex;
+
+ if (node.getChildAttrNode() != null) {
+ for (AttrNode childNode : node.getChildAttrNode()) {
+ dataIndex = childNode.getDataIndex();
+ if (dataIndex != null && dataIndex.size() > 0) {
+ // 说明此节点为叶子节点
+ leafNode.add(childNode);
+ } else {
+ // 如果还是非叶子节点则继续递归调用
+ addLeafNode(childNode, leafNode);
+ }
+ }
+ }
+ }
}
diff --git a/Classification/DataMining_CART/Client.java b/Classification/DataMining_CART/Client.java
index 0aabb0c..5c474a8 100644
--- a/Classification/DataMining_CART/Client.java
+++ b/Classification/DataMining_CART/Client.java
@@ -1,11 +1,11 @@
-package DataMining_CART;
+package Classification.DataMining_CART;
public class Client {
- public static void main(String[] args){
- String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
-
- CARTTool tool = new CARTTool(filePath);
-
- tool.startBuildingTree();
- }
+ public static void main(String[] args){
+ String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Classification/DataMining_CART/input.txt";
+
+ CARTTool tool = new CARTTool(filePath);
+
+ tool.startBuildingTree();
+ }
}
diff --git a/Classification/DataMining_ID3/AttrNode.java b/Classification/DataMining_ID3/AttrNode.java
index dfe2ce9..4018f81 100644
--- a/Classification/DataMining_ID3/AttrNode.java
+++ b/Classification/DataMining_ID3/AttrNode.java
@@ -1,51 +1,51 @@
-package DataMing_ID3;
+package Classification.DataMining_ID3;
import java.util.ArrayList;
/**
* 属性节点,不是叶子节点
- * @author lyq
*
+ * @author Qstar
*/
-public class AttrNode {
- //当前属性的名字
- private String attrName;
- //父节点的分类属性值
- private String parentAttrValue;
- //属性子节点
- private AttrNode[] childAttrNode;
- //孩子叶子节点
- private ArrayList childDataIndex;
-
- public String getAttrName() {
- return attrName;
- }
-
- public void setAttrName(String attrName) {
- this.attrName = attrName;
- }
-
- public AttrNode[] getChildAttrNode() {
- return childAttrNode;
- }
-
- public void setChildAttrNode(AttrNode[] childAttrNode) {
- this.childAttrNode = childAttrNode;
- }
-
- public String getParentAttrValue() {
- return parentAttrValue;
- }
-
- public void setParentAttrValue(String parentAttrValue) {
- this.parentAttrValue = parentAttrValue;
- }
-
- public ArrayList getChildDataIndex() {
- return childDataIndex;
- }
-
- public void setChildDataIndex(ArrayList childDataIndex) {
- this.childDataIndex = childDataIndex;
- }
+class AttrNode {
+ //当前属性的名字
+ private String attrName;
+ //父节点的分类属性值
+ private String parentAttrValue;
+ //属性子节点
+ private AttrNode[] childAttrNode;
+ //孩子叶子节点
+    private ArrayList<String> childDataIndex;
+
+ public String getAttrName(){
+ return attrName;
+ }
+
+ public void setAttrName(String attrName){
+ this.attrName = attrName;
+ }
+
+ AttrNode[] getChildAttrNode(){
+ return childAttrNode;
+ }
+
+ void setChildAttrNode(AttrNode[] childAttrNode){
+ this.childAttrNode = childAttrNode;
+ }
+
+ String getParentAttrValue(){
+ return parentAttrValue;
+ }
+
+ void setParentAttrValue(String parentAttrValue){
+ this.parentAttrValue = parentAttrValue;
+ }
+
+    ArrayList<String> getChildDataIndex(){
+        return childDataIndex;
+    }
+
+    void setChildDataIndex(ArrayList<String> childDataIndex){
+        this.childDataIndex = childDataIndex;
+    }
}
diff --git a/Classification/DataMining_ID3/Client.java b/Classification/DataMining_ID3/Client.java
index 9b70211..47aca9c 100644
--- a/Classification/DataMining_ID3/Client.java
+++ b/Classification/DataMining_ID3/Client.java
@@ -1,15 +1,15 @@
-package DataMing_ID3;
+package Classification.DataMining_ID3;
/**
* ID3决策树分类算法测试场景类
- * @author lyq
*
+ * @author Qstar
*/
public class Client {
- public static void main(String[] args){
- String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
-
- ID3Tool tool = new ID3Tool(filePath);
- tool.startBuildingTree(true);
- }
+ public static void main(String[] args){
+ String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Classification/DataMining_ID3/input.txt";
+
+ ID3Tool tool = new ID3Tool(filePath);
+ tool.startBuildingTree(true);
+ }
}
diff --git a/Classification/DataMining_ID3/DataNode.java b/Classification/DataMining_ID3/DataNode.java
index e29ab51..a623706 100644
--- a/Classification/DataMining_ID3/DataNode.java
+++ b/Classification/DataMining_ID3/DataNode.java
@@ -1,17 +1,17 @@
-package DataMing_ID3;
+package Classification.DataMining_ID3;
/**
* 存放数据的叶子节点
- * @author lyq
*
+ * @author Qstar
*/
public class DataNode {
- /**
- * 数据的标号
- */
- private int dataIndex;
-
- public DataNode(int dataIndex){
- this.dataIndex = dataIndex;
- }
+ /**
+ * 数据的标号
+ */
+ private int dataIndex;
+
+ public DataNode(int dataIndex){
+ this.dataIndex = dataIndex;
+ }
}
diff --git a/Classification/DataMining_ID3/ID3Tool.java b/Classification/DataMining_ID3/ID3Tool.java
index 67cfad7..0284eb5 100644
--- a/Classification/DataMining_ID3/ID3Tool.java
+++ b/Classification/DataMining_ID3/ID3Tool.java
@@ -1,447 +1,422 @@
-package DataMing_ID3;
+package Classification.DataMining_ID3;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
import java.util.HashMap;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.Map.Entry;
-import java.util.Set;
/**
* ID3算法实现类
- *
- * @author lyq
- *
+ *
+ * @author Qstar
*/
-public class ID3Tool {
- // 类标号的值类型
- private final String YES = "Yes";
- private final String NO = "No";
-
- // 所有属性的类型总数,在这里就是data源数据的列数
- private int attrNum;
- private String filePath;
- // 初始源数据,用一个二维字符数组存放模仿表格数据
- private String[][] data;
- // 数据的属性行的名字
- private String[] attrNames;
- // 每个属性的值所有类型
- private HashMap> attrValue;
-
- public ID3Tool(String filePath) {
- this.filePath = filePath;
- attrValue = new HashMap<>();
- }
-
- /**
- * 从文件中读取数据
- */
- private void readDataFile() {
- File file = new File(filePath);
- ArrayList dataArray = new ArrayList();
-
- try {
- BufferedReader in = new BufferedReader(new FileReader(file));
- String str;
- String[] tempArray;
- while ((str = in.readLine()) != null) {
- tempArray = str.split(" ");
- dataArray.add(tempArray);
- }
- in.close();
- } catch (IOException e) {
- e.getStackTrace();
- }
-
- data = new String[dataArray.size()][];
- dataArray.toArray(data);
- attrNum = data[0].length;
- attrNames = data[0];
+class ID3Tool {
+ // 类标号的值类型
+ private final String YES = "Yes";
+ private final String NO = "No";
+
+ // 所有属性的类型总数,在这里就是data源数据的列数
+ private int attrNum;
+ private String filePath;
+ // 初始源数据,用一个二维字符数组存放模仿表格数据
+ private String[][] data;
+ // 数据的属性行的名字
+ private String[] attrNames;
+ // 每个属性的值所有类型
+ private HashMap> attrValue;
+
+ ID3Tool(String filePath){
+ this.filePath = filePath;
+ attrValue = new HashMap<>();
+ }
+
+ /**
+ * 从文件中读取数据
+ */
+ private void readDataFile(){
+ File file = new File(filePath);
+ ArrayList dataArray = new ArrayList<>();
+
+ try {
+ BufferedReader in = new BufferedReader(new FileReader(file));
+ String str;
+ String[] tempArray;
+ while ((str = in.readLine()) != null) {
+ tempArray = str.split(" ");
+ dataArray.add(tempArray);
+ }
+ in.close();
+ } catch (IOException e) {
+ e.getStackTrace();
+ }
+
+ data = new String[dataArray.size()][];
+ dataArray.toArray(data);
+ attrNum = data[0].length;
+ attrNames = data[0];
/*
- * for(int i=0; i tempValues;
-
- // 按照列的方式,从左往右找
- for (int j = 1; j < attrNum; j++) {
- // 从一列中的上往下开始寻找值
- tempValues = new ArrayList<>();
- for (int i = 1; i < data.length; i++) {
- if (!tempValues.contains(data[i][j])) {
- // 如果这个属性的值没有添加过,则添加
- tempValues.add(data[i][j]);
- }
- }
-
- // 一列属性的值已经遍历完毕,复制到map属性表中
- attrValue.put(data[0][j], tempValues);
- }
+ }
+
+ /**
+ * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用
+ */
+ private void initAttrValue(){
+ ArrayList tempValues;
+
+ // 按照列的方式,从左往右找
+ for (int j = 1; j < attrNum; j++) {
+ // 从一列中的上往下开始寻找值
+ tempValues = new ArrayList<>();
+ for (int i = 1; i < data.length; i++) {
+ if (!tempValues.contains(data[i][j])) {
+ // 如果这个属性的值没有添加过,则添加
+ tempValues.add(data[i][j]);
+ }
+ }
+
+ // 一列属性的值已经遍历完毕,复制到map属性表中
+ attrValue.put(data[0][j], tempValues);
+ }
/*
- * for(Map.Entry entry : attrValue.entrySet()){
+ * for(Map.Entry entry : attrValue.entrySet()){
* System.out.println("key:value " + entry.getKey() + ":" +
* entry.getValue()); }
*/
- }
-
- /**
- * 计算数据按照不同方式划分的熵
- *
- * @param remainData
- * 剩余的数据
- * @param attrName
- * 待划分的属性,在算信息增益的时候会使用到
- * @param attrValue
- * 划分的子属性值
- * @param isParent
- * 是否分子属性划分还是原来不变的划分
- */
- private double computeEntropy(String[][] remainData, String attrName,
- String value, boolean isParent) {
- // 实例总数
- int total = 0;
- // 正实例数
- int posNum = 0;
- // 负实例数
- int negNum = 0;
-
- // 还是按列从左往右遍历属性
- for (int j = 1; j < attrNames.length; j++) {
- // 找到了指定的属性
- if (attrName.equals(attrNames[j])) {
- for (int i = 1; i < remainData.length; i++) {
- // 如果是父结点直接计算熵或者是通过子属性划分计算熵,这时要进行属性值的过滤
- if (isParent
- || (!isParent && remainData[i][j].equals(value))) {
- if (remainData[i][attrNames.length - 1].equals(YES)) {
- // 判断此行数据是否为正实例
- posNum++;
- } else {
- negNum++;
- }
- }
- }
- }
- }
-
- total = posNum + negNum;
- double posProbobly = (double) posNum / total;
- double negProbobly = (double) negNum / total;
-
- if (posProbobly == 1 || posProbobly == 0) {
- // 如果数据全为同种类型,则熵为0,否则带入下面的公式会报错
- return 0;
- }
-
- double entropyValue = -posProbobly * Math.log(posProbobly)
- / Math.log(2.0) - negProbobly * Math.log(negProbobly)
- / Math.log(2.0);
-
- // 返回计算所得熵
- return entropyValue;
- }
-
- /**
- * 为某个属性计算信息增益
- *
- * @param remainData
- * 剩余的数据
- * @param value
- * 待划分的属性名称
- * @return
- */
- private double computeGain(String[][] remainData, String value) {
- double gainValue = 0;
- // 源熵的大小将会与属性划分后进行比较
- double entropyOri = 0;
- // 子划分熵和
- double childEntropySum = 0;
- // 属性子类型的个数
- int childValueNum = 0;
- // 属性值的种数
- ArrayList attrTypes = attrValue.get(value);
- // 子属性对应的权重比
- HashMap ratioValues = new HashMap<>();
-
- for (int i = 0; i < attrTypes.size(); i++) {
- // 首先都统一计数为0
- ratioValues.put(attrTypes.get(i), 0);
- }
-
- // 还是按照一列,从左往右遍历
- for (int j = 1; j < attrNames.length; j++) {
- // 判断是否到了划分的属性列
- if (value.equals(attrNames[j])) {
- for (int i = 1; i <= remainData.length - 1; i++) {
- childValueNum = ratioValues.get(remainData[i][j]);
- // 增加个数并且重新存入
- childValueNum++;
- ratioValues.put(remainData[i][j], childValueNum);
- }
- }
- }
-
- // 计算原熵的大小
- entropyOri = computeEntropy(remainData, value, null, true);
- for (int i = 0; i < attrTypes.size(); i++) {
- double ratio = (double) ratioValues.get(attrTypes.get(i))
- / (remainData.length - 1);
- childEntropySum += ratio
- * computeEntropy(remainData, value, attrTypes.get(i), false);
-
- // System.out.println("ratio:value: " + ratio + " " +
- // computeEntropy(remainData, value,
- // attrTypes.get(i), false));
- }
-
- // 二者熵相减就是信息增益
- gainValue = entropyOri - childEntropySum;
- return gainValue;
- }
-
- /**
- * 计算信息增益比
- *
- * @param remainData
- * 剩余数据
- * @param value
- * 待划分属性
- * @return
- */
- private double computeGainRatio(String[][] remainData, String value) {
- double gain = 0;
- double spiltInfo = 0;
- int childValueNum = 0;
- // 属性值的种数
- ArrayList attrTypes = attrValue.get(value);
- // 子属性对应的权重比
- HashMap ratioValues = new HashMap<>();
-
- for (int i = 0; i < attrTypes.size(); i++) {
- // 首先都统一计数为0
- ratioValues.put(attrTypes.get(i), 0);
- }
-
- // 还是按照一列,从左往右遍历
- for (int j = 1; j < attrNames.length; j++) {
- // 判断是否到了划分的属性列
- if (value.equals(attrNames[j])) {
- for (int i = 1; i <= remainData.length - 1; i++) {
- childValueNum = ratioValues.get(remainData[i][j]);
- // 增加个数并且重新存入
- childValueNum++;
- ratioValues.put(remainData[i][j], childValueNum);
- }
- }
- }
-
- // 计算信息增益
- gain = computeGain(remainData, value);
- // 计算分裂信息,分裂信息度量被定义为(分裂信息用来衡量属性分裂数据的广度和均匀):
- for (int i = 0; i < attrTypes.size(); i++) {
- double ratio = (double) ratioValues.get(attrTypes.get(i))
- / (remainData.length - 1);
- spiltInfo += -ratio * Math.log(ratio) / Math.log(2.0);
- }
-
- // 计算机信息增益率
- return gain / spiltInfo;
- }
-
- /**
- * 利用源数据构造决策树
- */
- private void buildDecisionTree(AttrNode node, String parentAttrValue,
- String[][] remainData, ArrayList remainAttr, boolean isID3) {
- node.setParentAttrValue(parentAttrValue);
-
- String attrName = "";
- double gainValue = 0;
- double tempValue = 0;
-
- // 如果只有1个属性则直接返回
- if (remainAttr.size() == 1) {
- System.out.println("attr null");
- return;
- }
-
- // 选择剩余属性中信息增益最大的作为下一个分类的属性
- for (int i = 0; i < remainAttr.size(); i++) {
- // 判断是否用ID3算法还是C4.5算法
- if (isID3) {
- // ID3算法采用的是按照信息增益的值来比
- tempValue = computeGain(remainData, remainAttr.get(i));
- } else {
- // C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足
- tempValue = computeGainRatio(remainData, remainAttr.get(i));
- }
-
- if (tempValue > gainValue) {
- gainValue = tempValue;
- attrName = remainAttr.get(i);
- }
- }
-
- node.setAttrName(attrName);
- ArrayList valueTypes = attrValue.get(attrName);
- remainAttr.remove(attrName);
-
- AttrNode[] childNode = new AttrNode[valueTypes.size()];
- String[][] rData;
- for (int i = 0; i < valueTypes.size(); i++) {
- // 移除非此值类型的数据
- rData = removeData(remainData, attrName, valueTypes.get(i));
-
- childNode[i] = new AttrNode();
- boolean sameClass = true;
- ArrayList indexArray = new ArrayList<>();
- for (int k = 1; k < rData.length; k++) {
- indexArray.add(rData[k][0]);
- // 判断是否为同一类的
- if (!rData[k][attrNames.length - 1]
- .equals(rData[1][attrNames.length - 1])) {
- // 只要有1个不相等,就不是同类型的
- sameClass = false;
- break;
- }
- }
-
- if (!sameClass) {
- // 创建新的对象属性,对象的同个引用会出错
- ArrayList rAttr = new ArrayList<>();
- for (String str : remainAttr) {
- rAttr.add(str);
- }
-
- buildDecisionTree(childNode[i], valueTypes.get(i), rData,
- rAttr, isID3);
- } else {
- // 如果是同种类型,则直接为数据节点
- childNode[i].setParentAttrValue(valueTypes.get(i));
- childNode[i].setChildDataIndex(indexArray);
- }
-
- }
- node.setChildAttrNode(childNode);
- }
-
- /**
- * 属性划分完毕,进行数据的移除
- *
- * @param srcData
- * 源数据
- * @param attrName
- * 划分的属性名称
- * @param valueType
- * 属性的值类型
- */
- private String[][] removeData(String[][] srcData, String attrName,
- String valueType) {
- String[][] desDataArray;
- ArrayList desData = new ArrayList<>();
- // 待删除数据
- ArrayList selectData = new ArrayList<>();
- selectData.add(attrNames);
-
- // 数组数据转化到列表中,方便移除
- for (int i = 0; i < srcData.length; i++) {
- desData.add(srcData[i]);
- }
-
- // 还是从左往右一列列的查找
- for (int j = 1; j < attrNames.length; j++) {
- if (attrNames[j].equals(attrName)) {
- for (int i = 1; i < desData.size(); i++) {
- if (desData.get(i)[j].equals(valueType)) {
- // 如果匹配这个数据,则移除其他的数据
- selectData.add(desData.get(i));
- }
- }
- }
- }
-
- desDataArray = new String[selectData.size()][];
- selectData.toArray(desDataArray);
-
- return desDataArray;
- }
-
- /**
- * 开始构建决策树
- *
- * @param isID3
- * 是否采用ID3算法构架决策树
- */
- public void startBuildingTree(boolean isID3) {
- readDataFile();
- initAttrValue();
-
- ArrayList remainAttr = new ArrayList<>();
- // 添加属性,除了最后一个类标号属性
- for (int i = 1; i < attrNames.length - 1; i++) {
- remainAttr.add(attrNames[i]);
- }
-
- AttrNode rootNode = new AttrNode();
- buildDecisionTree(rootNode, "", data, remainAttr, isID3);
- showDecisionTree(rootNode, 1);
- }
-
- /**
- * 显示决策树
- *
- * @param node
- * 待显示的节点
- * @param blankNum
- * 行空格符,用于显示树型结构
- */
- private void showDecisionTree(AttrNode node, int blankNum) {
- System.out.println();
- for (int i = 0; i < blankNum; i++) {
- System.out.print("\t");
- }
- System.out.print("--");
- // 显示分类的属性值
- if (node.getParentAttrValue() != null
- && node.getParentAttrValue().length() > 0) {
- System.out.print(node.getParentAttrValue());
- } else {
- System.out.print("--");
- }
- System.out.print("--");
-
- if (node.getChildDataIndex() != null
- && node.getChildDataIndex().size() > 0) {
- String i = node.getChildDataIndex().get(0);
- System.out.print("类别:"
- + data[Integer.parseInt(i)][attrNames.length - 1]);
- System.out.print("[");
- for (String index : node.getChildDataIndex()) {
- System.out.print(index + ", ");
- }
- System.out.print("]");
- } else {
- // 递归显示子节点
- System.out.print("【" + node.getAttrName() + "】");
- for (AttrNode childNode : node.getChildAttrNode()) {
- showDecisionTree(childNode, 2 * blankNum);
- }
- }
-
- }
+ }
+
+ /**
+ * 计算数据按照不同方式划分的熵
+ *
+ * @param remainData 剩余的数据
+ * @param attrName 待划分的属性,在算信息增益的时候会使用到
+ * @param value 划分的子属性值
+ * @param isParent 是否分子属性划分还是原来不变的划分
+ */
+ private double computeEntropy(String[][] remainData, String attrName,
+ String value, boolean isParent){
+ // 实例总数
+ int total;
+ // 正实例数
+ int posNum = 0;
+ // 负实例数
+ int negNum = 0;
+
+ // 还是按列从左往右遍历属性
+ for (int j = 1; j < attrNames.length; j++) {
+ // 找到了指定的属性
+ if (attrName.equals(attrNames[j])) {
+ for (int i = 1; i < remainData.length; i++) {
+ // 如果是父结点直接计算熵或者是通过子属性划分计算熵,这时要进行属性值的过滤
+ if (isParent
+ || (!isParent && remainData[i][j].equals(value))) {
+ if (remainData[i][attrNames.length - 1].equals(YES)) {
+ // 判断此行数据是否为正实例
+ posNum++;
+ } else {
+ negNum++;
+ }
+ }
+ }
+ }
+ }
+
+ total = posNum + negNum;
+ double posProbobly = (double) posNum / total;
+ double negProbobly = (double) negNum / total;
+
+ if (posProbobly == 1 || posProbobly == 0) {
+ // 如果数据全为同种类型,则熵为0,否则带入下面的公式会报错
+ return 0;
+ }
+
+ // 返回计算所得熵
+ return -posProbobly * Math.log(posProbobly)
+ / Math.log(2.0) - negProbobly * Math.log(negProbobly)
+ / Math.log(2.0);
+ }
+
+ /**
+ * 为某个属性计算信息增益
+ *
+ * @param remainData 剩余的数据
+ * @param value 待划分的属性名称
+ */
+ private double computeGain(String[][] remainData, String value){
+ double gainValue;
+ // 源熵的大小将会与属性划分后进行比较
+ double entropyOri;
+ // 子划分熵和
+ double childEntropySum = 0;
+ // 属性子类型的个数
+ int childValueNum;
+ // 属性值的种数
+ ArrayList attrTypes = attrValue.get(value);
+ // 子属性对应的权重比
+ HashMap ratioValues = new HashMap<>();
+
+ for (String attrType : attrTypes) {
+ // 首先都统一计数为0
+ ratioValues.put(attrType, 0);
+ }
+
+ // 还是按照一列,从左往右遍历
+ for (int j = 1; j < attrNames.length; j++) {
+ // 判断是否到了划分的属性列
+ if (value.equals(attrNames[j])) {
+ for (int i = 1; i <= remainData.length - 1; i++) {
+ childValueNum = ratioValues.get(remainData[i][j]);
+ // 增加个数并且重新存入
+ childValueNum++;
+ ratioValues.put(remainData[i][j], childValueNum);
+ }
+ }
+ }
+
+ // 计算原熵的大小
+ entropyOri = computeEntropy(remainData, value, null, true);
+ for (String attrType : attrTypes) {
+ double ratio = (double) ratioValues.get(attrType)
+ / (remainData.length - 1);
+ childEntropySum += ratio
+ * computeEntropy(remainData, value, attrType, false);
+
+ // System.out.println("ratio:value: " + ratio + " " +
+ // computeEntropy(remainData, value,
+ // attrTypes.get(i), false));
+ }
+
+ // 二者熵相减就是信息增益
+ gainValue = entropyOri - childEntropySum;
+ return gainValue;
+ }
+
+ /**
+ * 计算信息增益比
+ *
+ * @param remainData 剩余数据
+ * @param value 待划分属性
+ */
+ private double computeGainRatio(String[][] remainData, String value){
+ double gain;
+ double spiltInfo = 0;
+ int childValueNum;
+ // 属性值的种数
+ ArrayList attrTypes = attrValue.get(value);
+ // 子属性对应的权重比
+ HashMap ratioValues = new HashMap<>();
+
+ for (String attrType : attrTypes) {
+ // 首先都统一计数为0
+ ratioValues.put(attrType, 0);
+ }
+
+ // 还是按照一列,从左往右遍历
+ for (int j = 1; j < attrNames.length; j++) {
+ // 判断是否到了划分的属性列
+ if (value.equals(attrNames[j])) {
+ for (int i = 1; i <= remainData.length - 1; i++) {
+ childValueNum = ratioValues.get(remainData[i][j]);
+ // 增加个数并且重新存入
+ childValueNum++;
+ ratioValues.put(remainData[i][j], childValueNum);
+ }
+ }
+ }
+
+ // 计算信息增益
+ gain = computeGain(remainData, value);
+ // 计算分裂信息,分裂信息度量被定义为(分裂信息用来衡量属性分裂数据的广度和均匀):
+ for (String attrType : attrTypes) {
+ double ratio = (double) ratioValues.get(attrType)
+ / (remainData.length - 1);
+ spiltInfo += -ratio * Math.log(ratio) / Math.log(2.0);
+ }
+
+ // 计算信息增益率
+ return gain / spiltInfo;
+ }
+
+ /**
+ * 利用源数据构造决策树
+ */
+ private void buildDecisionTree(AttrNode node, String parentAttrValue,
+ String[][] remainData, ArrayList remainAttr, boolean isID3){
+ node.setParentAttrValue(parentAttrValue);
+
+ String attrName = "";
+ double gainValue = 0;
+ double tempValue;
+
+ // 如果只有1个属性则直接返回
+ if (remainAttr.size() == 1) {
+ System.out.println("attr null");
+ return;
+ }
+
+ // 选择剩余属性中信息增益最大的作为下一个分类的属性
+ for (String aRemainAttr : remainAttr) {
+ // 判断是否用ID3算法还是C4.5算法
+ if (isID3) {
+ // ID3算法采用的是按照信息增益的值来比
+ tempValue = computeGain(remainData, aRemainAttr);
+ } else {
+ // C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足
+ tempValue = computeGainRatio(remainData, aRemainAttr);
+ }
+
+ if (tempValue > gainValue) {
+ gainValue = tempValue;
+ attrName = aRemainAttr;
+ }
+ }
+
+ node.setAttrName(attrName);
+ ArrayList valueTypes = attrValue.get(attrName);
+ remainAttr.remove(attrName);
+
+ AttrNode[] childNode = new AttrNode[valueTypes.size()];
+ String[][] rData;
+ for (int i = 0; i < valueTypes.size(); i++) {
+ // 移除非此值类型的数据
+ rData = removeData(remainData, attrName, valueTypes.get(i));
+
+ childNode[i] = new AttrNode();
+ boolean sameClass = true;
+ ArrayList indexArray = new ArrayList<>();
+ for (int k = 1; k < rData.length; k++) {
+ indexArray.add(rData[k][0]);
+ // 判断是否为同一类的
+ if (!rData[k][attrNames.length - 1]
+ .equals(rData[1][attrNames.length - 1])) {
+ // 只要有1个不相等,就不是同类型的
+ sameClass = false;
+ break;
+ }
+ }
+
+ if (!sameClass) {
+ // 创建新的对象属性,对象的同个引用会出错
+ ArrayList rAttr = new ArrayList<>();
+ for (String str : remainAttr) {
+ rAttr.add(str);
+ }
+
+ buildDecisionTree(childNode[i], valueTypes.get(i), rData,
+ rAttr, isID3);
+ } else {
+ // 如果是同种类型,则直接为数据节点
+ childNode[i].setParentAttrValue(valueTypes.get(i));
+ childNode[i].setChildDataIndex(indexArray);
+ }
+
+ }
+ node.setChildAttrNode(childNode);
+ }
+
+ /**
+ * 属性划分完毕,进行数据的移除
+ *
+ * @param srcData 源数据
+ * @param attrName 划分的属性名称
+ * @param valueType 属性的值类型
+ */
+ private String[][] removeData(String[][] srcData, String attrName,
+ String valueType){
+ String[][] desDataArray;
+ ArrayList desData = new ArrayList<>();
+ // 待删除数据
+ ArrayList selectData = new ArrayList<>();
+ selectData.add(attrNames);
+
+ // 数组数据转化到列表中,方便移除
+ Collections.addAll(desData, srcData);
+
+ // 还是从左往右一列列的查找
+ for (int j = 1; j < attrNames.length; j++) {
+ if (attrNames[j].equals(attrName)) {
+ for (int i = 1; i < desData.size(); i++) {
+ if (desData.get(i)[j].equals(valueType)) {
+ // 如果匹配这个数据,则移除其他的数据
+ selectData.add(desData.get(i));
+ }
+ }
+ }
+ }
+
+ desDataArray = new String[selectData.size()][];
+ selectData.toArray(desDataArray);
+
+ return desDataArray;
+ }
+
+ /**
+ * 开始构建决策树
+ *
+ * @param isID3 是否采用ID3算法构建决策树
+ */
+ void startBuildingTree(boolean isID3){
+ readDataFile();
+ initAttrValue();
+
+ ArrayList remainAttr = new ArrayList<>();
+ // 添加属性,除了最后一个类标号属性
+ remainAttr.addAll(Arrays.asList(attrNames).subList(1, attrNames.length - 1));
+
+ AttrNode rootNode = new AttrNode();
+ buildDecisionTree(rootNode, "", data, remainAttr, isID3);
+ showDecisionTree(rootNode, 1);
+ }
+
+ /**
+ * 显示决策树
+ *
+ * @param node 待显示的节点
+ * @param blankNum 行空格符,用于显示树型结构
+ */
+ private void showDecisionTree(AttrNode node, int blankNum){
+ System.out.println();
+ for (int i = 0; i < blankNum; i++) {
+ System.out.print("\t");
+ }
+ System.out.print("--");
+ // 显示分类的属性值
+ if (node.getParentAttrValue() != null
+ && node.getParentAttrValue().length() > 0) {
+ System.out.print(node.getParentAttrValue());
+ } else {
+ System.out.print("--");
+ }
+ System.out.print("--");
+
+ if (node.getChildDataIndex() != null
+ && node.getChildDataIndex().size() > 0) {
+ String i = node.getChildDataIndex().get(0);
+ System.out.print("类别:"
+ + data[Integer.parseInt(i)][attrNames.length - 1]);
+ System.out.print("[");
+ for (String index : node.getChildDataIndex()) {
+ System.out.print(index + ", ");
+ }
+ System.out.print("]");
+ } else {
+ // 递归显示子节点
+ System.out.print("【" + node.getAttrName() + "】");
+ for (AttrNode childNode : node.getChildAttrNode()) {
+ showDecisionTree(childNode, 2 * blankNum);
+ }
+ }
+
+ }
}
diff --git a/Classification/DataMining_KNN/Client.java b/Classification/DataMining_KNN/Client.java
index 03e1d5a..51f8574 100644
--- a/Classification/DataMining_KNN/Client.java
+++ b/Classification/DataMining_KNN/Client.java
@@ -1,26 +1,16 @@
-package DataMining_KNN;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.List;
-
+package Classification.DataMining_KNN;
/**
* k最近邻算法场景类型
- * @author lyq
*
+ * @author Qstar
*/
public class Client {
- public static void main(String[] args){
- String trainDataPath = "C:\\Users\\lyq\\Desktop\\icon\\trainInput.txt";
- String testDataPath = "C:\\Users\\lyq\\Desktop\\icon\\testinput.txt";
-
- KNNTool tool = new KNNTool(trainDataPath, testDataPath);
- tool.knnCompute(3);
-
- }
-
-
+ public static void main(String[] args){
+ String trainDataPath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Classification/DataMining_KNN/trainInput.txt";
+ String testDataPath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Classification/DataMining_KNN/testinput.txt";
+ KNNTool tool = new KNNTool(trainDataPath, testDataPath);
+ tool.knnCompute(3);
+ }
}
diff --git a/Classification/DataMining_KNN/KNNTool.java b/Classification/DataMining_KNN/KNNTool.java
index bb316fb..af24af5 100644
--- a/Classification/DataMining_KNN/KNNTool.java
+++ b/Classification/DataMining_KNN/KNNTool.java
@@ -1,200 +1,186 @@
-package DataMining_KNN;
+package Classification.DataMining_KNN;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
import java.util.Collections;
-import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
-import org.apache.activemq.filter.ComparisonExpression;
-
/**
* k最近邻算法工具类
- *
- * @author lyq
- *
+ *
+ * @author Qstar
*/
-public class KNNTool {
- // 为4个类别设置权重,默认权重比一致
- public int[] classWeightArray = new int[] { 1, 1, 1, 1 };
- // 测试数据
- private String testDataPath;
- // 训练集数据地址
- private String trainDataPath;
- // 分类的不同类型
- private ArrayList classTypes;
- // 结果数据
- private ArrayList resultSamples;
- // 训练集数据列表容器
- private ArrayList trainSamples;
- // 训练集数据
- private String[][] trainData;
- // 测试集数据
- private String[][] testData;
-
- public KNNTool(String trainDataPath, String testDataPath) {
- this.trainDataPath = trainDataPath;
- this.testDataPath = testDataPath;
- readDataFormFile();
- }
-
- /**
- * 从文件中阅读测试数和训练数据集
- */
- private void readDataFormFile() {
- ArrayList tempArray;
-
- tempArray = fileDataToArray(trainDataPath);
- trainData = new String[tempArray.size()][];
- tempArray.toArray(trainData);
-
- classTypes = new ArrayList<>();
- for (String[] s : tempArray) {
- if (!classTypes.contains(s[0])) {
- // 添加类型
- classTypes.add(s[0]);
- }
- }
-
- tempArray = fileDataToArray(testDataPath);
- testData = new String[tempArray.size()][];
- tempArray.toArray(testData);
- }
-
- /**
- * 将文件转为列表数据输出
- *
- * @param filePath
- * 数据文件的内容
- */
- private ArrayList fileDataToArray(String filePath) {
- File file = new File(filePath);
- ArrayList dataArray = new ArrayList();
-
- try {
- BufferedReader in = new BufferedReader(new FileReader(file));
- String str;
- String[] tempArray;
- while ((str = in.readLine()) != null) {
- tempArray = str.split(" ");
- dataArray.add(tempArray);
- }
- in.close();
- } catch (IOException e) {
- e.getStackTrace();
- }
-
- return dataArray;
- }
-
- /**
- * 计算样本特征向量的欧几里得距离
- *
- * @param f1
- * 待比较样本1
- * @param f2
- * 待比较样本2
- * @return
- */
- private int computeEuclideanDistance(Sample s1, Sample s2) {
- String[] f1 = s1.getFeatures();
- String[] f2 = s2.getFeatures();
- // 欧几里得距离
- int distance = 0;
-
- for (int i = 0; i < f1.length; i++) {
- int subF1 = Integer.parseInt(f1[i]);
- int subF2 = Integer.parseInt(f2[i]);
-
- distance += (subF1 - subF2) * (subF1 - subF2);
- }
-
- return distance;
- }
-
- /**
- * 计算K最近邻
- * @param k
- * 在多少的k范围内
- */
- public void knnCompute(int k) {
- String className = "";
- String[] tempF = null;
- Sample temp;
- resultSamples = new ArrayList<>();
- trainSamples = new ArrayList<>();
- // 分类类别计数
- HashMap classCount;
- // 类别权重比
- HashMap classWeight = new HashMap<>();
- // 首先讲测试数据转化到结果数据中
- for (String[] s : testData) {
- temp = new Sample(s);
- resultSamples.add(temp);
- }
-
- for (String[] s : trainData) {
- className = s[0];
- tempF = new String[s.length - 1];
- System.arraycopy(s, 1, tempF, 0, s.length - 1);
- temp = new Sample(className, tempF);
- trainSamples.add(temp);
- }
-
- // 离样本最近排序的的训练集数据
- ArrayList kNNSample = new ArrayList<>();
- // 计算训练数据集中离样本数据最近的K个训练集数据
- for (Sample s : resultSamples) {
- classCount = new HashMap<>();
- int index = 0;
- for (String type : classTypes) {
- // 开始时计数为0
- classCount.put(type, 0);
- classWeight.put(type, classWeightArray[index++]);
- }
- for (Sample tS : trainSamples) {
- int dis = computeEuclideanDistance(s, tS);
- tS.setDistance(dis);
- }
-
- Collections.sort(trainSamples);
- kNNSample.clear();
- // 挑选出前k个数据作为分类标准
- for (int i = 0; i < trainSamples.size(); i++) {
- if (i < k) {
- kNNSample.add(trainSamples.get(i));
- } else {
- break;
- }
- }
- // 判定K个训练数据的多数的分类标准
- for (Sample s1 : kNNSample) {
- int num = classCount.get(s1.getClassName());
- // 进行分类权重的叠加,默认类别权重平等,可自行改变,近的权重大,远的权重小
- num += classWeight.get(s1.getClassName());
- classCount.put(s1.getClassName(), num);
- }
-
- int maxCount = 0;
- // 筛选出k个训练集数据中最多的一个分类
- for (Map.Entry entry : classCount.entrySet()) {
- if ((Integer) entry.getValue() > maxCount) {
- maxCount = (Integer) entry.getValue();
- s.setClassName((String) entry.getKey());
- }
- }
-
- System.out.print("测试数据特征:");
- for (String s1 : s.getFeatures()) {
- System.out.print(s1 + " ");
- }
- System.out.println("分类:" + s.getClassName());
- }
- }
+class KNNTool {
+ // 为4个类别设置权重,默认权重比一致
+ private int[] classWeightArray = new int[]{1, 1, 1, 1};
+ // 测试数据
+ private String testDataPath;
+ // 训练集数据地址
+ private String trainDataPath;
+ // 分类的不同类型
+ private ArrayList classTypes;
+ // 训练集数据
+ private String[][] trainData;
+ // 测试集数据
+ private String[][] testData;
+
+ KNNTool(String trainDataPath, String testDataPath){
+ this.trainDataPath = trainDataPath;
+ this.testDataPath = testDataPath;
+ readDataFormFile();
+ }
+
+ /**
+ * 从文件中阅读测试数据和训练数据集
+ */
+ private void readDataFormFile(){
+ ArrayList tempArray;
+
+ tempArray = fileDataToArray(trainDataPath);
+ trainData = new String[tempArray.size()][];
+ tempArray.toArray(trainData);
+
+ classTypes = new ArrayList<>();
+ for (String[] s : tempArray) {
+ if (!classTypes.contains(s[0])) {
+ // 添加类型
+ classTypes.add(s[0]);
+ }
+ }
+
+ tempArray = fileDataToArray(testDataPath);
+ testData = new String[tempArray.size()][];
+ tempArray.toArray(testData);
+ }
+
+ /**
+ * 将文件转为列表数据输出
+ *
+ * @param filePath 数据文件的内容
+ */
+ private ArrayList fileDataToArray(String filePath){
+ File file = new File(filePath);
+ ArrayList dataArray = new ArrayList<>();
+
+ try {
+ BufferedReader in = new BufferedReader(new FileReader(file));
+ String str;
+ String[] tempArray;
+ while ((str = in.readLine()) != null) {
+ tempArray = str.split(" ");
+ dataArray.add(tempArray);
+ }
+ in.close();
+ } catch (IOException e) {
+ e.getStackTrace();
+ }
+
+ return dataArray;
+ }
+
+ /**
+ * 计算样本特征向量的欧几里得距离
+ *
+ * @param s1 待比较样本1
+ * @param s2 待比较样本2
+ */
+ private int computeEuclideanDistance(Sample s1, Sample s2){
+ String[] f1 = s1.getFeatures();
+ String[] f2 = s2.getFeatures();
+ // 欧几里得距离
+ int distance = 0;
+
+ for (int i = 0; i < f1.length; i++) {
+ int subF1 = Integer.parseInt(f1[i]);
+ int subF2 = Integer.parseInt(f2[i]);
+
+ distance += (subF1 - subF2) * (subF1 - subF2);
+ }
+
+ return distance;
+ }
+
+ /**
+ * 计算K最近邻
+ *
+ * @param k 在多少的k范围内
+ */
+ void knnCompute(int k){
+ String className;
+ String[] tempF;
+ Sample temp;
+ ArrayList resultSamples = new ArrayList<>();
+ ArrayList trainSamples = new ArrayList<>();
+ // 分类类别计数
+ HashMap classCount;
+ // 类别权重比
+ HashMap classWeight = new HashMap<>();
+ // 首先将测试数据转化到结果数据中
+ for (String[] s : testData) {
+ temp = new Sample(s);
+ resultSamples.add(temp);
+ }
+
+ for (String[] s : trainData) {
+ className = s[0];
+ tempF = new String[s.length - 1];
+ System.arraycopy(s, 1, tempF, 0, s.length - 1);
+ temp = new Sample(className, tempF);
+ trainSamples.add(temp);
+ }
+
+ // 离样本最近排序的训练集数据
+ ArrayList kNNSample = new ArrayList<>();
+ // 计算训练数据集中离样本数据最近的K个训练集数据
+ for (Sample s : resultSamples) {
+ classCount = new HashMap<>();
+ int index = 0;
+ for (String type : classTypes) {
+ // 开始时计数为0
+ classCount.put(type, 0);
+ classWeight.put(type, classWeightArray[index++]);
+ }
+ for (Sample tS : trainSamples) {
+ int dis = computeEuclideanDistance(s, tS);
+ tS.setDistance(dis);
+ }
+
+ Collections.sort(trainSamples);
+ kNNSample.clear();
+ // 挑选出前k个数据作为分类标准
+ for (int i = 0; i < trainSamples.size(); i++) {
+ if (i < k) {
+ kNNSample.add(trainSamples.get(i));
+ } else {
+ break;
+ }
+ }
+ // 判定K个训练数据的多数的分类标准
+ for (Sample s1 : kNNSample) {
+ int num = classCount.get(s1.getClassName());
+ // 进行分类权重的叠加,默认类别权重平等,可自行改变,近的权重大,远的权重小
+ num += classWeight.get(s1.getClassName());
+ classCount.put(s1.getClassName(), num);
+ }
+
+ int maxCount = 0;
+ // 筛选出k个训练集数据中最多的一个分类
+ for (Map.Entry entry : classCount.entrySet()) {
+ if ((Integer) entry.getValue() > maxCount) {
+ maxCount = (Integer) entry.getValue();
+ s.setClassName((String) entry.getKey());
+ }
+ }
+
+ System.out.print("测试数据特征:");
+ for (String s1 : s.getFeatures()) {
+ System.out.print(s1 + " ");
+ }
+ System.out.println("分类:" + s.getClassName());
+ }
+ }
}
diff --git a/Classification/DataMining_KNN/Sample.java b/Classification/DataMining_KNN/Sample.java
index c4e185d..bf6e3e9 100644
--- a/Classification/DataMining_KNN/Sample.java
+++ b/Classification/DataMining_KNN/Sample.java
@@ -1,57 +1,52 @@
-package DataMining_KNN;
+package Classification.DataMining_KNN;
/**
* 样本数据类
- *
- * @author lyq
- *
+ *
+ * @author Qstar
*/
-public class Sample implements Comparable{
- // 样本数据的分类名称
- private String className;
- // 样本数据的特征向量
- private String[] features;
- //测试样本之间的间距值,以此做排序
- private Integer distance;
-
- public Sample(String[] features){
- this.features = features;
- }
-
- public Sample(String className, String[] features){
- this.className = className;
- this.features = features;
- }
-
- public String getClassName() {
- return className;
- }
-
- public void setClassName(String className) {
- this.className = className;
- }
-
- public String[] getFeatures() {
- return features;
- }
-
- public void setFeatures(String[] features) {
- this.features = features;
- }
-
- public Integer getDistance() {
- return distance;
- }
-
- public void setDistance(int distance) {
- this.distance = distance;
- }
-
- @Override
- public int compareTo(Sample o) {
- // TODO Auto-generated method stub
- return this.getDistance().compareTo(o.getDistance());
- }
-
+class Sample implements Comparable {
+ // 样本数据的分类名称
+ private String className;
+ // 样本数据的特征向量
+ private String[] features;
+ //测试样本之间的间距值,以此做排序
+ private Integer distance;
+
+ Sample(String[] features){
+ this.features = features;
+ }
+
+ Sample(String className, String[] features){
+ this.className = className;
+ this.features = features;
+ }
+
+ String getClassName(){
+ return className;
+ }
+
+ void setClassName(String className){
+ this.className = className;
+ }
+
+ String[] getFeatures(){
+ return features;
+ }
+
+ public Integer getDistance(){
+ return distance;
+ }
+
+ public void setDistance(int distance){
+ this.distance = distance;
+ }
+
+ @Override
+ public int compareTo(Sample o){
+ // TODO Auto-generated method stub
+ return this.getDistance().compareTo(o.getDistance());
+ }
+
}
diff --git a/Classification/DataMining_NaiveBayes/Client.java b/Classification/DataMining_NaiveBayes/Client.java
index 05713b7..3e307b9 100644
--- a/Classification/DataMining_NaiveBayes/Client.java
+++ b/Classification/DataMining_NaiveBayes/Client.java
@@ -1,17 +1,17 @@
-package DataMining_NaiveBayes;
+package Classification.DataMining_NaiveBayes;
/**
* 朴素贝叶斯算法场景调用类
- * @author lyq
*
+ * @author Qstar
*/
public class Client {
- public static void main(String[] args){
- //训练集数据
- String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
- String testData = "Youth Medium Yes Fair";
- NaiveBayesTool tool = new NaiveBayesTool(filePath);
- System.out.println(testData + " 数据的分类为:" + tool.naiveBayesClassificate(testData));
- }
+ public static void main(String[] args){
+ //训练集数据
+ String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Classification/DataMining_NaiveBayes/input.txt";
+ String testData = "Youth Medium Yes Fair";
+ NaiveBayesTool tool = new NaiveBayesTool(filePath);
+ System.out.println(testData + " 数据的分类为:" + tool.naiveBayesClassificate(testData));
+ }
}
diff --git a/Classification/DataMining_NaiveBayes/NaiveBayesTool.java b/Classification/DataMining_NaiveBayes/NaiveBayesTool.java
index 02db2c6..6239058 100644
--- a/Classification/DataMining_NaiveBayes/NaiveBayesTool.java
+++ b/Classification/DataMining_NaiveBayes/NaiveBayesTool.java
@@ -1,4 +1,4 @@
-package DataMining_NaiveBayes;
+package Classification.DataMining_NaiveBayes;
import java.io.BufferedReader;
import java.io.File;
@@ -10,200 +10,193 @@
/**
* 朴素贝叶斯算法工具类
- *
- * @author lyq
- *
+ *
+ * @author Qstar
*/
-public class NaiveBayesTool {
- // 类标记符,这里分为2类,YES和NO
- private String YES = "Yes";
- private String NO = "No";
-
- // 已分类训练数据集文件路径
- private String filePath;
- // 属性名称数组
- private String[] attrNames;
- // 训练数据集
- private String[][] data;
-
- // 每个属性的值所有类型
- private HashMap> attrValue;
-
- public NaiveBayesTool(String filePath) {
- this.filePath = filePath;
-
- readDataFile();
- initAttrValue();
- }
-
- /**
- * 从文件中读取数据
- */
- private void readDataFile() {
- File file = new File(filePath);
- ArrayList dataArray = new ArrayList();
-
- try {
- BufferedReader in = new BufferedReader(new FileReader(file));
- String str;
- String[] tempArray;
- while ((str = in.readLine()) != null) {
- tempArray = str.split(" ");
- dataArray.add(tempArray);
- }
- in.close();
- } catch (IOException e) {
- e.getStackTrace();
- }
-
- data = new String[dataArray.size()][];
- dataArray.toArray(data);
- attrNames = data[0];
+class NaiveBayesTool {
+ // 类标记符,这里分为2类,YES和NO
+ private String YES = "Yes";
+
+ // 已分类训练数据集文件路径
+ private String filePath;
+ // 属性名称数组
+ private String[] attrNames;
+ // 训练数据集
+ private String[][] data;
+
+ // 每个属性的值所有类型
+ private HashMap> attrValue;
+
+ NaiveBayesTool(String filePath){
+ this.filePath = filePath;
+
+ readDataFile();
+ initAttrValue();
+ }
+
+ /**
+ * 从文件中读取数据
+ */
+ private void readDataFile(){
+ File file = new File(filePath);
+ ArrayList dataArray = new ArrayList<>();
+
+ try {
+ BufferedReader in = new BufferedReader(new FileReader(file));
+ String str;
+ String[] tempArray;
+ while ((str = in.readLine()) != null) {
+ tempArray = str.split(" ");
+ dataArray.add(tempArray);
+ }
+ in.close();
+ } catch (IOException e) {
+ e.getStackTrace();
+ }
+
+ data = new String[dataArray.size()][];
+ dataArray.toArray(data);
+ attrNames = data[0];
/*
- * for(int i=0; i();
- ArrayList tempValues;
-
- // 按照列的方式,从左往右找
- for (int j = 1; j < attrNames.length; j++) {
- // 从一列中的上往下开始寻找值
- tempValues = new ArrayList<>();
- for (int i = 1; i < data.length; i++) {
- if (!tempValues.contains(data[i][j])) {
- // 如果这个属性的值没有添加过,则添加
- tempValues.add(data[i][j]);
- }
- }
-
- // 一列属性的值已经遍历完毕,复制到map属性表中
- attrValue.put(data[0][j], tempValues);
- }
-
- }
-
- /**
- * 在classType的情况下,发生condition条件的概率
- *
- * @param condition
- * 属性条件
- * @param classType
- * 分类的类型
- * @return
- */
- private double computeConditionProbably(String condition, String classType) {
- // 条件计数器
- int count = 0;
- // 条件属性的索引列
- int attrIndex = 1;
- // yes类标记符数据
- ArrayList yClassData = new ArrayList<>();
- // no类标记符数据
- ArrayList nClassData = new ArrayList<>();
- ArrayList classData;
-
- for (int i = 1; i < data.length; i++) {
- // data数据按照yes和no分类
- if (data[i][attrNames.length - 1].equals(YES)) {
- yClassData.add(data[i]);
- } else {
- nClassData.add(data[i]);
- }
- }
-
- if (classType.equals(YES)) {
- classData = yClassData;
- } else {
- classData = nClassData;
- }
-
- // 如果没有设置条件则,计算的是纯粹的类事件概率
- if (condition == null) {
- return 1.0 * classData.size() / (data.length - 1);
- }
-
- // 寻找此条件的属性列
- attrIndex = getConditionAttrName(condition);
-
- for (String[] s : classData) {
- if (s[attrIndex].equals(condition)) {
- count++;
- }
- }
-
- return 1.0 * count / classData.size();
- }
-
- /**
- * 根据条件值返回条件所属属性的列值
- *
- * @param condition
- * 条件
- * @return
- */
- private int getConditionAttrName(String condition) {
- // 条件所属属性名
- String attrName = "";
- // 条件所在属性列索引
- int attrIndex = 1;
- // 临时属性值类型
- ArrayList valueTypes;
- for (Map.Entry entry : attrValue.entrySet()) {
- valueTypes = (ArrayList) entry.getValue();
- if (valueTypes.contains(condition)
- && !((String) entry.getKey()).equals("BuysComputer")) {
- attrName = (String) entry.getKey();
- }
- }
-
- for (int i = 0; i < attrNames.length - 1; i++) {
- if (attrNames[i].equals(attrName)) {
- attrIndex = i;
- break;
- }
- }
-
- return attrIndex;
- }
-
- /**
- * 进行朴素贝叶斯分类
- *
- * @param data
- * 待分类数据
- */
- public String naiveBayesClassificate(String data) {
- // 测试数据的属性值特征
- String[] dataFeatures;
- // 在yes的条件下,x事件发生的概率
- double xWhenYes = 1.0;
- // 在no的条件下,x事件发生的概率
- double xWhenNo = 1.0;
- // 最后也是yes和no分类的总概率,用P(X|Ci)*P(Ci)的公式计算
- double pYes = 1;
- double pNo = 1;
-
- dataFeatures = data.split(" ");
- for (int i = 0; i < dataFeatures.length; i++) {
- // 因为朴素贝叶斯算法是类条件独立的,所以可以进行累积的计算
- xWhenYes *= computeConditionProbably(dataFeatures[i], YES);
- xWhenNo *= computeConditionProbably(dataFeatures[i], NO);
- }
-
- pYes = xWhenYes * computeConditionProbably(null, YES);
- pNo = xWhenNo * computeConditionProbably(null, NO);
-
- return (pYes > pNo ? YES : NO);
- }
+ }
+
+ /**
+ * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用
+ */
+ private void initAttrValue(){
+ attrValue = new HashMap<>();
+ ArrayList tempValues;
+
+ // 按照列的方式,从左往右找
+ for (int j = 1; j < attrNames.length; j++) {
+ // 从一列中的上往下开始寻找值
+ tempValues = new ArrayList<>();
+ for (int i = 1; i < data.length; i++) {
+ if (!tempValues.contains(data[i][j])) {
+ // 如果这个属性的值没有添加过,则添加
+ tempValues.add(data[i][j]);
+ }
+ }
+
+ // 一列属性的值已经遍历完毕,复制到map属性表中
+ attrValue.put(data[0][j], tempValues);
+ }
+
+ }
+
+ /**
+ * 在classType的情况下,发生condition条件的概率
+ *
+ * @param condition 属性条件
+ * @param classType 分类的类型
+ */
+ private double computeConditionProbably(String condition, String classType){
+ // 条件计数器
+ int count = 0;
+ // 条件属性的索引列
+ int attrIndex;
+ // yes类标记符数据
+ ArrayList yClassData = new ArrayList<>();
+ // no类标记符数据
+ ArrayList nClassData = new ArrayList<>();
+ ArrayList classData;
+
+ for (int i = 1; i < data.length; i++) {
+ // data数据按照yes和no分类
+ if (data[i][attrNames.length - 1].equals(YES)) {
+ yClassData.add(data[i]);
+ } else {
+ nClassData.add(data[i]);
+ }
+ }
+
+ if (classType.equals(YES)) {
+ classData = yClassData;
+ } else {
+ classData = nClassData;
+ }
+
+ // 如果没有设置条件则,计算的是纯粹的类事件概率
+ if (condition == null) {
+ return 1.0 * classData.size() / (data.length - 1);
+ }
+
+ // 寻找此条件的属性列
+ attrIndex = getConditionAttrName(condition);
+
+ for (String[] s : classData) {
+ if (s[attrIndex].equals(condition)) {
+ count++;
+ }
+ }
+
+ return 1.0 * count / classData.size();
+ }
+
+ /**
+ * 根据条件值返回条件所属属性的列值
+ *
+ * @param condition 条件
+ */
+ private int getConditionAttrName(String condition){
+ // 条件所属属性名
+ String attrName = "";
+ // 条件所在属性列索引
+ int attrIndex = 1;
+ // 临时属性值类型
+ ArrayList valueTypes;
+ for (Map.Entry entry : attrValue.entrySet()) {
+ valueTypes = (ArrayList) entry.getValue();
+ if (valueTypes.contains(condition)
+ && !entry.getKey().equals("BuysComputer")) {
+ attrName = (String) entry.getKey();
+ }
+ }
+
+ for (int i = 0; i < attrNames.length - 1; i++) {
+ if (attrNames[i].equals(attrName)) {
+ attrIndex = i;
+ break;
+ }
+ }
+
+ return attrIndex;
+ }
+
+ /**
+ * 进行朴素贝叶斯分类
+ *
+ * @param data 待分类数据
+ */
+ String naiveBayesClassificate(String data){
+ // 测试数据的属性值特征
+ String[] dataFeatures;
+ // 在yes的条件下,x事件发生的概率
+ double xWhenYes = 1.0;
+ // 在no的条件下,x事件发生的概率
+ double xWhenNo = 1.0;
+ // 最后也是yes和no分类的总概率,用P(X|Ci)*P(Ci)的公式计算
+ double pYes;
+ double pNo;
+
+ dataFeatures = data.split(" ");
+ String NO = "No";
+ for (String dataFeature : dataFeatures) {
+ // 因为朴素贝叶斯算法是类条件独立的,所以可以进行累积的计算
+ xWhenYes *= computeConditionProbably(dataFeature, YES);
+ xWhenNo *= computeConditionProbably(dataFeature, NO);
+ }
+
+ pYes = xWhenYes * computeConditionProbably(null, YES);
+ pNo = xWhenNo * computeConditionProbably(null, NO);
+
+ return (pYes > pNo ? YES : NO);
+ }
}
diff --git a/Clustering/DataMining_BIRCH/BIRCHTool.java b/Clustering/DataMining_BIRCH/BIRCHTool.java
index 3aec6aa..be3f039 100644
--- a/Clustering/DataMining_BIRCH/BIRCHTool.java
+++ b/Clustering/DataMining_BIRCH/BIRCHTool.java
@@ -1,4 +1,4 @@
-package DataMining_BIRCH;
+package Clustering.DataMining_BIRCH;
import java.io.BufferedReader;
import java.io.File;
@@ -10,243 +10,240 @@
/**
* BIRCH聚类算法工具类
- *
- * @author lyq
- *
+ *
+ * @author Qstar
*/
-public class BIRCHTool {
- // 节点类型名称
- public static final String NON_LEAFNODE = "【NonLeafNode】";
- public static final String LEAFNODE = "【LeafNode】";
- public static final String CLUSTER = "【Cluster】";
-
- // 测试数据文件地址
- private String filePath;
- // 内部节点平衡因子B
- public static int B;
- // 叶子节点平衡因子L
- public static int L;
- // 簇直径阈值T
- public static double T;
- // 总的测试数据记录
- private ArrayList totalDataRecords;
-
- public BIRCHTool(String filePath, int B, int L, double T) {
- this.filePath = filePath;
- this.B = B;
- this.L = L;
- this.T = T;
- readDataFile();
- }
-
- /**
- * 从文件中读取数据
- */
- private void readDataFile() {
- File file = new File(filePath);
- ArrayList dataArray = new ArrayList();
-
- try {
- BufferedReader in = new BufferedReader(new FileReader(file));
- String str;
- String[] tempArray;
- while ((str = in.readLine()) != null) {
- tempArray = str.split(" ");
- dataArray.add(tempArray);
- }
- in.close();
- } catch (IOException e) {
- e.getStackTrace();
- }
-
- totalDataRecords = new ArrayList<>();
- for (String[] array : dataArray) {
- totalDataRecords.add(array);
- }
- }
-
- /**
- * 构建CF聚类特征树
- *
- * @return
- */
- private ClusteringFeature buildCFTree() {
- NonLeafNode rootNode = null;
- LeafNode leafNode = null;
- Cluster cluster = null;
-
- for (String[] record : totalDataRecords) {
- cluster = new Cluster(record);
-
- if (rootNode == null) {
- // CF树只有1个节点的时候的情况
- if (leafNode == null) {
- leafNode = new LeafNode();
- }
- leafNode.addingCluster(cluster);
- if (leafNode.getParentNode() != null) {
- rootNode = leafNode.getParentNode();
- }
- } else {
- if (rootNode.getParentNode() != null) {
- rootNode = rootNode.getParentNode();
- }
-
- // 从根节点开始,从上往下寻找到最近的添加目标叶子节点
- LeafNode temp = rootNode.findedClosestNode(cluster);
- temp.addingCluster(cluster);
- }
- }
-
- // 从下往上找出最上面的节点
- LeafNode node = cluster.getParentNode();
- NonLeafNode upNode = node.getParentNode();
- if (upNode == null) {
- return node;
- } else {
- while (upNode.getParentNode() != null) {
- upNode = upNode.getParentNode();
- }
-
- return upNode;
- }
- }
-
- /**
- * 开始构建CF聚类特征树
- */
- public void startBuilding() {
- // 树深度
- int level = 1;
- ClusteringFeature rootNode = buildCFTree();
-
- setTreeLevel(rootNode, level);
- showCFTree(rootNode);
- }
-
- /**
- * 设置节点深度
- *
- * @param clusteringFeature
- * 当前节点
- * @param level
- * 当前深度值
- */
- private void setTreeLevel(ClusteringFeature clusteringFeature, int level) {
- LeafNode leafNode = null;
- NonLeafNode nonLeafNode = null;
-
- if (clusteringFeature instanceof LeafNode) {
- leafNode = (LeafNode) clusteringFeature;
- } else if (clusteringFeature instanceof NonLeafNode) {
- nonLeafNode = (NonLeafNode) clusteringFeature;
- }
-
- if (nonLeafNode != null) {
- nonLeafNode.setLevel(level);
- level++;
- // 设置子节点
- if (nonLeafNode.getNonLeafChilds() != null) {
- for (NonLeafNode n1 : nonLeafNode.getNonLeafChilds()) {
- setTreeLevel(n1, level);
- }
- } else {
- for (LeafNode n2 : nonLeafNode.getLeafChilds()) {
- setTreeLevel(n2, level);
- }
- }
- } else {
- leafNode.setLevel(level);
- level++;
- // 设置子聚簇
- for (Cluster c : leafNode.getClusterChilds()) {
- c.setLevel(level);
- }
- }
- }
-
- /**
- * 显示CF聚类特征树
- *
- * @param rootNode
- * CF树根节点
- */
- private void showCFTree(ClusteringFeature rootNode) {
- // 空格数,用于输出
- int blankNum = 5;
- // 当前树深度
- int currentLevel = 1;
- LinkedList nodeQueue = new LinkedList<>();
- ClusteringFeature cf;
- LeafNode leafNode;
- NonLeafNode nonLeafNode;
- ArrayList clusterList = new ArrayList<>();
- String typeName;
-
- nodeQueue.add(rootNode);
- while (nodeQueue.size() > 0) {
- cf = nodeQueue.poll();
-
- if (cf instanceof LeafNode) {
- leafNode = (LeafNode) cf;
- typeName = LEAFNODE;
-
- if (leafNode.getClusterChilds() != null) {
- for (Cluster c : leafNode.getClusterChilds()) {
- nodeQueue.add(c);
- }
- }
- } else if (cf instanceof NonLeafNode) {
- nonLeafNode = (NonLeafNode) cf;
- typeName = NON_LEAFNODE;
-
- if (nonLeafNode.getNonLeafChilds() != null) {
- for (NonLeafNode n1 : nonLeafNode.getNonLeafChilds()) {
- nodeQueue.add(n1);
- }
- } else {
- for (LeafNode n2 : nonLeafNode.getLeafChilds()) {
- nodeQueue.add(n2);
- }
- }
- } else {
- clusterList.add((Cluster)cf);
- typeName = CLUSTER;
- }
-
- if (currentLevel != cf.getLevel()) {
- currentLevel = cf.getLevel();
- System.out.println();
- System.out.println("|");
- System.out.println("|");
- }else if(currentLevel == cf.getLevel() && currentLevel != 1){
- for (int i = 0; i < blankNum; i++) {
- System.out.print("-");
- }
- }
-
- System.out.print(typeName);
- System.out.print("N:" + cf.getN() + ", LS:");
- System.out.print("[");
- for (double d : cf.getLS()) {
- System.out.print(MessageFormat.format("{0}, ", d));
- }
- System.out.print("]");
- }
-
- System.out.println();
- System.out.println("*******最终分好的聚簇****");
- //显示已经分好类的聚簇点
- for(int i=0; i totalDataRecords;
+
+ BIRCHTool(String filePath, int B, int L, double T){
+ this.filePath = filePath;
+ BIRCHTool.B = B;
+ BIRCHTool.L = L;
+ BIRCHTool.T = T;
+ readDataFile();
+ }
+
+ /**
+ * 从文件中读取数据
+ */
+ private void readDataFile(){
+ File file = new File(filePath);
+ ArrayList dataArray = new ArrayList<>();
+
+ try {
+ BufferedReader in = new BufferedReader(new FileReader(file));
+ String str;
+ String[] tempArray;
+ while ((str = in.readLine()) != null) {
+ tempArray = str.split(" ");
+ dataArray.add(tempArray);
+ }
+ in.close();
+ } catch (IOException e) {
+ e.getStackTrace();
+ }
+
+ totalDataRecords = new ArrayList<>();
+ for (String[] array : dataArray) {
+ totalDataRecords.add(array);
+ }
+ }
+
+ /**
+ * 构建CF聚类特征树
+ */
+ private ClusteringFeature buildCFTree(){
+ NonLeafNode rootNode = null;
+ LeafNode leafNode = null;
+ Cluster cluster = null;
+
+ for (String[] record : totalDataRecords) {
+ cluster = new Cluster(record);
+
+ if (rootNode == null) {
+ // CF树只有1个节点的时候的情况
+ if (leafNode == null) {
+ leafNode = new LeafNode();
+ }
+ leafNode.addingCluster(cluster);
+ if (leafNode.getParentNode() != null) {
+ rootNode = leafNode.getParentNode();
+ }
+ } else {
+ if (rootNode.getParentNode() != null) {
+ rootNode = rootNode.getParentNode();
+ }
+
+ // 从根节点开始,从上往下寻找到最近的添加目标叶子节点
+ LeafNode temp = rootNode.findedClosestNode(cluster);
+ temp.addingCluster(cluster);
+ }
+ }
+
+ // 从下往上找出最上面的节点
+ LeafNode node = cluster != null ? cluster.getParentNode() : null;
+ NonLeafNode upNode = node != null ? node.getParentNode() : null;
+ if (upNode == null) {
+ return node;
+ } else {
+ while (upNode.getParentNode() != null) {
+ upNode = upNode.getParentNode();
+ }
+
+ return upNode;
+ }
+ }
+
+ /**
+ * 开始构建CF聚类特征树
+ */
+ void startBuilding(){
+ // 树深度
+ int level = 1;
+ ClusteringFeature rootNode = buildCFTree();
+
+ setTreeLevel(rootNode, level);
+ showCFTree(rootNode);
+ }
+
+ /**
+ * 设置节点深度
+ *
+ * @param clusteringFeature 当前节点
+ * @param level 当前深度值
+ */
+ private void setTreeLevel(ClusteringFeature clusteringFeature, int level){
+ LeafNode leafNode = null;
+ NonLeafNode nonLeafNode = null;
+
+ if (clusteringFeature instanceof LeafNode) {
+ leafNode = (LeafNode) clusteringFeature;
+ } else if (clusteringFeature instanceof NonLeafNode) {
+ nonLeafNode = (NonLeafNode) clusteringFeature;
+ }
+
+ if (nonLeafNode != null) {
+ nonLeafNode.setLevel(level);
+ level++;
+ // 设置子节点
+ if (nonLeafNode.getNonLeafChilds() != null) {
+ for (NonLeafNode n1 : nonLeafNode.getNonLeafChilds()) {
+ setTreeLevel(n1, level);
+ }
+ } else {
+ for (LeafNode n2 : nonLeafNode.getLeafChilds()) {
+ setTreeLevel(n2, level);
+ }
+ }
+ } else {
+ if (leafNode != null) {
+ leafNode.setLevel(level);
+ }
+ level++;
+ // 设置子聚簇
+ if (leafNode != null) {
+ for (Cluster c : leafNode.getClusterChilds()) {
+ c.setLevel(level);
+ }
+ }
+ }
+ }
+
+ /**
+ * 显示CF聚类特征树
+ *
+ * @param rootNode CF树根节点
+ */
+ private void showCFTree(ClusteringFeature rootNode){
+ // 空格数,用于输出
+ int blankNum = 5;
+ // 当前树深度
+ int currentLevel = 1;
+ LinkedList nodeQueue = new LinkedList<>();
+ ClusteringFeature cf;
+ LeafNode leafNode;
+ NonLeafNode nonLeafNode;
+ ArrayList clusterList = new ArrayList<>();
+ String typeName;
+
+ nodeQueue.add(rootNode);
+ while (nodeQueue.size() > 0) {
+ cf = nodeQueue.poll();
+
+ if (cf instanceof LeafNode) {
+ leafNode = (LeafNode) cf;
+ typeName = LEAFNODE;
+
+ if (leafNode.getClusterChilds() != null) {
+ for (Cluster c : leafNode.getClusterChilds()) {
+ nodeQueue.add(c);
+ }
+ }
+ } else if (cf instanceof NonLeafNode) {
+ nonLeafNode = (NonLeafNode) cf;
+ typeName = NON_LEAFNODE;
+
+ if (nonLeafNode.getNonLeafChilds() != null) {
+ for (NonLeafNode n1 : nonLeafNode.getNonLeafChilds()) {
+ nodeQueue.add(n1);
+ }
+ } else {
+ for (LeafNode n2 : nonLeafNode.getLeafChilds()) {
+ nodeQueue.add(n2);
+ }
+ }
+ } else {
+ clusterList.add((Cluster) cf);
+ typeName = CLUSTER;
+ }
+
+ if (currentLevel != cf.getLevel()) {
+ currentLevel = cf.getLevel();
+ System.out.println();
+ System.out.println("|");
+ System.out.println("|");
+ } else if (currentLevel == cf.getLevel() && currentLevel != 1) {
+ for (int i = 0; i < blankNum; i++) {
+ System.out.print("-");
+ }
+ }
+
+ System.out.print(typeName);
+ System.out.print("N:" + cf.getN() + ", LS:");
+ System.out.print("[");
+ for (double d : cf.getLS()) {
+ System.out.print(MessageFormat.format("{0}, ", d));
+ }
+ System.out.print("]");
+ }
+
+ System.out.println();
+ System.out.println("*******最终分好的聚簇****");
+ //显示已经分好类的聚簇点
+ for (int i = 0; i < clusterList.size(); i++) {
+ System.out.println("Cluster" + (i + 1) + ":");
+ for (double[] point : clusterList.get(i).getData()) {
+ System.out.print("[");
+ for (double d : point) {
+ System.out.print(MessageFormat.format("{0}, ", d));
+ }
+ System.out.println("]");
+ }
+ }
+ }
}
diff --git a/Clustering/DataMining_BIRCH/Client.java b/Clustering/DataMining_BIRCH/Client.java
index 0f6ea28..125de7b 100644
--- a/Clustering/DataMining_BIRCH/Client.java
+++ b/Clustering/DataMining_BIRCH/Client.java
@@ -1,21 +1,21 @@
-package DataMining_BIRCH;
+package Clustering.DataMining_BIRCH;
/**
* BIRCH聚类算法调用类
- * @author lyq
*
+ * @author Qstar
*/
public class Client {
- public static void main(String[] args){
- String filePath = "C:\\Users\\lyq\\Desktop\\icon\\testInput.txt";
- //内部节点平衡因子B
- int B = 2;
- //叶子节点平衡因子L
- int L = 2;
- //簇直径阈值T
- double T = 0.6;
-
- BIRCHTool tool = new BIRCHTool(filePath, B, L, T);
- tool.startBuilding();
- }
+ public static void main(String[] args){
+ String filePath = "/Users/Qstar/Desktop/DataMiningAlgorithm/Clustering/DataMining_BIRCH/testInput.txt";
+ //内部节点平衡因子B
+ int B = 2;
+ //叶子节点平衡因子L
+ int L = 2;
+ //簇直径阈值T
+ double T = 0.6;
+
+ BIRCHTool tool = new BIRCHTool(filePath, B, L, T);
+ tool.startBuilding();
+ }
}
diff --git a/Clustering/DataMining_BIRCH/Cluster.java b/Clustering/DataMining_BIRCH/Cluster.java
index 0edda4f..5269264 100644
--- a/Clustering/DataMining_BIRCH/Cluster.java
+++ b/Clustering/DataMining_BIRCH/Cluster.java
@@ -1,60 +1,56 @@
-package DataMining_BIRCH;
+package Clustering.DataMining_BIRCH;
import java.util.ArrayList;
/**
* 叶子节点中的小集群
- * @author lyq
*
+ * @author Qstar
*/
-public class Cluster extends ClusteringFeature{
- //集群中的数据点
- private ArrayList data;
- //父亲节点
- private LeafNode parentNode;
-
- public Cluster(String[] record){
- double[] d = new double[record.length];
- data = new ArrayList<>();
- for(int i=0; i getData() {
- return data;
- }
-
- public void setData(ArrayList data) {
- this.data = data;
- }
-
- @Override
- protected void directAddCluster(ClusteringFeature node) {
- //如果是聚类包括数据记录,则还需合并数据记录
- Cluster c = (Cluster)node;
- ArrayList dataRecords = c.getData();
- this.data.addAll(dataRecords);
-
- super.directAddCluster(node);
- }
-
- public LeafNode getParentNode() {
- return parentNode;
- }
-
- public void setParentNode(LeafNode parentNode) {
- this.parentNode = parentNode;
- }
-
- @Override
- public void addingCluster(ClusteringFeature clusteringFeature) {
- // TODO Auto-generated method stub
-
- }
+class Cluster extends ClusteringFeature {
+ //集群中的数据点
+ private ArrayList data;
+ //父亲节点
+ private LeafNode parentNode;
+
+ Cluster(String[] record){
+ double[] d = new double[record.length];
+ data = new ArrayList<>();
+ for (int i = 0; i < record.length; i++) {
+ d[i] = Double.parseDouble(record[i]);
+ }
+ data.add(d);
+ //计算CF聚类特征
+ this.setLS(data);
+ this.setSS(data);
+ this.setN(data);
+ }
+
+ ArrayList getData(){
+ return data;
+ }
+
+ @Override
+ protected void directAddCluster(ClusteringFeature node){
+ //如果是聚类包括数据记录,则还需合并数据记录
+ Cluster c = (Cluster) node;
+ ArrayList