Skip to content

Commit

Permalink
fixed: 最优化ref计算
Browse files Browse the repository at this point in the history
  • Loading branch information
F-ca7 committed Oct 8, 2019
1 parent 759e04f commit f3b6f7c
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public abstract class OptimizationAlgorithm {
public abstract ArrayList<Double> getModifiedColumn();



/**
* 获取ref值,供sigmoid函数使用
* @author Fang
Expand All @@ -29,41 +30,21 @@ public abstract class OptimizationAlgorithm {
* @update Kun-提出中位数分割法
*/
public static double getRefValue(ArrayList<Double> colValues, double secretKey){
// DescriptiveStatistics stats = new DescriptiveStatistics();
// //把数组的值添加到统计量中
// colValues.forEach(stats::addValue);
// double mean = stats.getMean();
// double varianceSqrt = Math.sqrt(stats.getVariance());
// double ref = mean + secretKey*varianceSqrt;

ArrayList<Double> tmpList = new ArrayList<>(colValues);
tmpList.sort(Comparator.naturalOrder());
double mid = tmpList.get(colValues.size()/2);

// if (varianceSqrt>mean) {
//
// // 若数据集过于分散,则进行均化处理
// double sum = 0;
// for (double i:colValues){
// sum += getSigmoid(i, ref);
// }
// logger.info("标准差大于均值,返回{}", sum/(double)colValues.size());
// return sum/(double)colValues.size();
// }
//logger.info("标准差小于均值,返回{}", mid);
return mid;

//把数组的值添加到统计量中
logger.info("ref为{}",PatternSearch.OREF);
return PatternSearch.OREF;
}


public static double getOHidingValue(ArrayList<Double> colValues, double secretKey){
double ref = getRefValue(colValues, secretKey);
double sum = 0.0;
for (double i:colValues){
sum += getSigmoid(i, PatternSearch.OREF);
sum += getSigmoid(i, ref);
}
return sum/(double)colValues.size();
}


private static double getSigmoid(double i, double oref) {
double ALPHA = 8;
return (1.0-1.0/(1+Math.exp(ALPHA*(i-oref))));
Expand All @@ -88,11 +69,10 @@ public static double calcOptimizedThreshold(ArrayList<Double> minList, ArrayList

double minMean = minStats.getMean();
double maxMean = maxStats.getMean();
double minVar = minStats.getVariance();
double maxVar = maxStats.getVariance();

logger.info("min均值:{}, 方差:{}", minMean, minVar);
logger.info("max均值:{}, 方差:{}", maxMean, maxVar);
// double minVar = minStats.getVariance();
// double maxVar = maxStats.getVariance();
// logger.info("min均值:{}, 方差:{}", minMean, minVar);
// logger.info("max均值:{}, 方差:{}", maxMean, maxVar);

return (minMean+maxMean)/2;

Expand All @@ -103,9 +83,6 @@ public static double calcOptimizedThreshold(ArrayList<Double> minList, ArrayList
* @Description 返回一元二次方程较小的根
* @author Fang
* @date 2019/3/26 17:07
* @param A
* @param B
* @param C
* @return double
*/
private static double getSmallerRootForQuad(double A, double B, double C){
Expand Down
42 changes: 24 additions & 18 deletions WMtest/src/main/java/team/aster/algorithm/PatternSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import java.util.ArrayList;
import java.util.Collections;
/**
*
* 模式搜索--最优化
* @author kun
*
*/
public final class PatternSearch extends OptimizationAlgorithm{

Expand All @@ -21,10 +20,15 @@ public final class PatternSearch extends OptimizationAlgorithm{
private ArrayList<Double> changeRecord;//增向量
private double ALPHA = 8;//sigmoid函数参数
private double REF;//临时sigmoid函数参数
static double OREF;//sigmoid隐藏函数参数
public static double OREF;//sigmoid隐藏函数参数
private boolean IS_MAX;
private double exp = -0.00001;

/**
* 设置上下界
* @param lower 下界
* @param upper 上界
*/
private void setBound(double lower,double upper) {
UPPER_BOUND = upper;
LOWER_BOUND = lower;
Expand All @@ -35,12 +39,16 @@ private void setRef(double ref) {
REF = ref;
OREF = ref;
}

private double getSigmoid(double x) {
return (1.0-1.0/(1+Math.exp(ALPHA*(x-REF))));

}

private void setNextStepLength() {
STEP_LENGTH = STEP_LENGTH * DECAY_RATE;
}

private boolean cmp(double x,double y) {
if(IS_MAX) {
return (x-y)>exp;
Expand All @@ -66,7 +74,7 @@ private void stepByStep() {
int len = tmp.size();
for(int i=0;i<len;i++) {
double valueI = tmp.get(i);
if(cmp(getSigmoid(valueI+STEP_LENGTH),getSigmoid(valueI))&&changeRecord.get(i)<UPPER_BOUND) {
if(cmp(getSigmoid(valueI+STEP_LENGTH), getSigmoid(valueI)) && changeRecord.get(i)<UPPER_BOUND) {
tmp.set(i, valueI+STEP_LENGTH);
changeRecord.set(i, changeRecord.get(i)+STEP_LENGTH);
}else if(cmp(getSigmoid(valueI-STEP_LENGTH),getSigmoid(valueI))&&changeRecord.get(i)>LOWER_BOUND) {
Expand Down Expand Up @@ -101,8 +109,8 @@ private void searchByPattern() {
break;
}
tmp.set(i, recordState.get(i)+ACCURATE*(recordState.get(i)-initState.get(i)));
x+=getSigmoid(tmp.get(i));
y+=getSigmoid(recordState.get(i));
x += getSigmoid(tmp.get(i));
y += getSigmoid(recordState.get(i));
tmpChange.set(i, tmpChange.get(i)+ACCURATE*(recordState.get(i)-initState.get(i)));
}

Expand Down Expand Up @@ -131,19 +139,17 @@ public ArrayList<Double> getResult(){
}




public PatternSearch() {

}

/**
* 获取最大最优化后的隐藏函数均值meanMax
* @param colValues
* @param ref
* @param lower
* @param upper
* @return
* @param colValues 一列数据
* @param ref 参考值
* @param lower 修改变化下界
* @param upper 修改变化上界
* @return 隐藏函数最大值
*/
@Override
public double maximizeByHidingFunction(ArrayList<Double> colValues, double ref, double lower, double upper) {
Expand All @@ -163,11 +169,11 @@ private double initParams(ArrayList<Double> colValues, double ref, double lower,

/**
* 获取最小最优化后的隐藏函数均值meanMin
* @param colValues
* @param ref
* @param lower
* @param upper
* @return
* @param colValues 一列数据
* @param ref 参考值
* @param lower 修改变化下界
* @param upper 修改变化上界
* @return 隐藏函数最小值
*/
@Override
public double minimizeByHidingFunction(ArrayList<Double> colValues, double ref, double lower, double upper) {
Expand Down
36 changes: 30 additions & 6 deletions WMtest/src/main/java/team/aster/processor/OptimDecoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import team.aster.algorithm.Divider;
import team.aster.algorithm.GenericOptimization;
import team.aster.algorithm.OptimizationAlgorithm;
import team.aster.algorithm.PatternSearch;
import team.aster.model.DatasetWithPK;
import team.aster.model.PartitionedDataset;
import team.aster.model.StoredKey;
import team.aster.utils.Constants;

import java.util.ArrayList;
import java.util.Map;

/**
* 基于最优化算法的水印解码器
*/
public class OptimDecoder implements IDecoder {
private static Logger logger = LoggerFactory.getLogger(OptimDecoder.class);

Expand All @@ -24,7 +26,6 @@ public class OptimDecoder implements IDecoder {
private int minLength;
//先只对一列进行嵌入水印解码,这里是最后一列FLATLOSE 转让盈亏(已扣税)
//但是由于列之间的约束,这里还是不太科学
// todo 这个地方应该自定义
private int COL_INDEX;

public int getPartitionCount() {
Expand Down Expand Up @@ -87,14 +88,34 @@ private String detectWatermark(PartitionedDataset partitionedDataset){
int[] ones = new int[wmLength];
int[] zeros = new int[wmLength];
Map<Integer, ArrayList<ArrayList<String>>> map = partitionedDataset.getPartitionedDataset();

ArrayList<Double> all = new ArrayList<>();
map.forEach((k,v)->{
for(ArrayList<String> row: v){
// 只取一列数据
double value = Double.valueOf(row.get(COL_INDEX));
all.add(value);
}
});
all.sort(Double::compareTo);
// 使用secretKey进行映射
int start = ((int)(20+secretKey*10))*all.size()/100;
int end = ((int)(80-secretKey*10))*all.size()/100;
double mean = 0d;

for(int i=start;i<end;i++){
mean+=all.get(i)/(end-start);
}
PatternSearch.OREF = mean;

map.forEach((k, v)->{
if(v.size() >= minLength){
ArrayList<Double> colValues = new ArrayList<>();
int index = k%wmLength;
v.forEach(strValues->{
colValues.add(Double.parseDouble(strValues.get(COL_INDEX)));
});
double hidingValue = GenericOptimization.getOHidingValue(colValues, secretKey);
double hidingValue = OptimizationAlgorithm.getOHidingValue(colValues, secretKey);
if (hidingValue > this.threshold) {
ones[index]++;
} else{
Expand All @@ -106,11 +127,14 @@ private String detectWatermark(PartitionedDataset partitionedDataset){
//据ones和zeros生成水印
StringBuilder wm = new StringBuilder();
for(int i=0;i<wmLength;i++){
if(ones[i]>zeros[i]){
if(ones[i] > zeros[i]){
logger.info("第{}位有{}个0,{}个1,解得1", i, zeros[i], ones[i]);
wm.append("1");
}else if(ones[i]<zeros[i]){
logger.info("第{}位有{}个0,{}个1,解得0", i, zeros[i], ones[i]);
wm.append("0");
}else{
logger.info("第{}位有{}个0,{}个1,解得x", i, zeros[i], ones[i]);
wm.append("x");
}
}
Expand Down
45 changes: 34 additions & 11 deletions WMtest/src/main/java/team/aster/processor/OptimEncoder.java
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
package team.aster.processor;

import gui.InstantInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import team.aster.algorithm.Divider;
import team.aster.algorithm.GenericOptimization;
import team.aster.algorithm.OptimizationAlgorithm;
import team.aster.algorithm.PatternSearch;
import team.aster.model.DatasetWithPK;
import team.aster.model.PartitionedDataset;
import team.aster.model.WaterMark;
import team.aster.model.WatermarkException;
import team.aster.utils.Constants;

import java.util.ArrayList;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

/**
* 基于最优化算法的水印编码器
*/
public class OptimEncoder extends IEncoderNumericImpl {
private static Logger logger = LoggerFactory.getLogger(OptimDecoder.class);
private static Logger logger = LoggerFactory.getLogger(OptimEncoder.class);

private ArrayList<Double> minList = new ArrayList<>();
private ArrayList<Double> maxList = new ArrayList<>();
Expand Down Expand Up @@ -90,21 +89,43 @@ public void setEmbedColIndex(int embedColIndex) {

/**
* @Description 对划分好的数据集嵌入水印,直接修改划分里的数据集
* @author Fcat
* @author Fang
* @date 2019/3/24 16:47
* @param partitionedDataset 整个划分好的数据集
* @param watermark 要嵌入的水印串
*/
private double encodeAllBits(PartitionedDataset partitionedDataset, ArrayList<Integer> watermark, double secretKey){
System.out.println("开始嵌入水印所有位");
logger.info("开始嵌入水印所有位");
Map<Integer, ArrayList<ArrayList<String>>> datasetWithIndex = partitionedDataset.getPartitionedDataset();
final int wmLength = watermark.size();


ArrayList<Double> all = new ArrayList<>();
datasetWithIndex.forEach((k,v)->{
for(ArrayList<String> row: v){
// 只取一列数据
double value = Double.valueOf(row.get(COL_INDEX));
all.add(value);
}
});
all.sort(Double::compareTo);

int start = ((int)(20+secretKey*10))*all.size()/100;
int end = ((int)(80-secretKey*10))*all.size()/100;
double mean = 0d;
ArrayList<Double> cutCol = new ArrayList<>();
for(int i=start;i<end;i++){
cutCol.add(all.get(i));
mean+=all.get(i)/(end-start);
}
PatternSearch.OREF = mean;


datasetWithIndex.forEach((k,v)->{
int index = k%wmLength;
encodeSingleBit(v, secretKey, watermark.get(index));
});
double threshold = GenericOptimization.calcOptimizedThreshold(minList, maxList);
double threshold = OptimizationAlgorithm.calcOptimizedThreshold(minList, maxList);
logger.debug("阈值为: {}", threshold);
return threshold;
}
Expand All @@ -113,7 +134,7 @@ private double encodeAllBits(PartitionedDataset partitionedDataset, ArrayList<In

/**
* @Description 对水印的一个bit嵌入一个划分当中,直接对划分进行修改
* @author Fcat
* @author Fang
* @date 2019/3/24 1:18
* @param partition 一个划分
* @param secretKey 水印对应的bit位
Expand All @@ -138,12 +159,12 @@ private void encodeSingleBit(ArrayList<ArrayList<String>> partition, double secr
switch (bit){
case 0:
tmp = optimization.minimizeByHidingFunction(colValues,
OptimizationAlgorithm.getHidingValue(colValues, secretKey), varLowerBound, varUpperBound);
OptimizationAlgorithm.getRefValue(colValues, secretKey), varLowerBound, varUpperBound);
minList.add(tmp);
break;
case 1:
tmp = optimization.maximizeByHidingFunction(colValues,
OptimizationAlgorithm.getHidingValue(colValues, secretKey), varLowerBound, varUpperBound);
OptimizationAlgorithm.getRefValue(colValues, secretKey), varLowerBound, varUpperBound);
maxList.add(tmp);
break;
default:
Expand All @@ -170,6 +191,8 @@ private void encodeSingleBit(ArrayList<ArrayList<String>> partition, double secr
case DOUBLE:
placeholder = "%."+len+"f";
break;
default:
break;
}

// 写回partition
Expand Down

0 comments on commit f3b6f7c

Please sign in to comment.