Skip to content

Commit

Permalink
Make sure degree model components working with directed networks and …
Browse files Browse the repository at this point in the history
…that centre node choices are acknowledged in the case of internal stars
  • Loading branch information
narnolddd committed Apr 18, 2023
1 parent 7c0c46e commit 0115f37
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 45 deletions.
57 changes: 33 additions & 24 deletions feta/FitAndCloneRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
import feta.network.UndirectedNetwork;
import feta.objectmodels.FullObjectModel;
import feta.objectmodels.MixedModel;
import feta.objectmodels.components.DegreeModelComponent;
import feta.objectmodels.components.ObjectModelComponent;
import feta.objectmodels.components.RandomAttachment;
import feta.objectmodels.components.*;
import feta.operations.MixedOperations;
import feta.operations.OperationModel;
import feta.readnet.ReadNet;
Expand All @@ -25,31 +23,42 @@ public class FitAndCloneRunner {
public static void main(String[] args) {
// Read in network to be fitted
ReadNet reader = new ReadNetCSV("typeTest.dat"," ",true,0,1,2,3,4);
Network net = new DirectedNetwork(reader, true);
double[] degreePowerParms = new double[] {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.2};

// Specify object model to be tested
ArrayList<ObjectModelComponent> components = new ArrayList<>() {
{
add( new DegreeModelComponent());
add( new RandomAttachment());
}
};
double bestLikelihood = 0.0;
FullObjectModel bestObm = null;
for (double d: degreePowerParms) {
Network net = new DirectedNetwork(reader, true);

MixedModel model = new MixedModel(components);
// Specify object model to be tested
ArrayList<ObjectModelComponent> components = new ArrayList<>() {
{
add(new RandomAttachment());
add(new DegreeModelComponent());
}
};

// Fit mixed model
FitMixedModel fit = new FitMixedModel(net,model,100,10,false);
fit.execute();
MixedModel model = new MixedModel(components);

// Get out parsed operation model and fitted object model
OperationModel om = new MixedOperations(fit.getParsedOperations());
FullObjectModel obm = fit.getFittedModel();
// Fit mixed model
FitMixedModel fit = new FitMixedModel(net, model, 100, 10, false);
fit.execute();

// Grow network from this operation and object model and write to csv
Network grownNet = new DirectedNetwork(reader,true);
Grow grow = new Grow(grownNet,obm,om,10L,100L);
grow.execute();
WriteNet writer = new WriteNetCSV(grownNet, " ", "testOutput.csv");
writer.write();
// Get out parsed operation model and fitted object model
OperationModel om = new MixedOperations(fit.getParsedOperations());
FullObjectModel obm = fit.getFittedModel();
double currentLikelihood = fit.getBestLikelihood();
if (currentLikelihood > bestLikelihood) {
bestLikelihood = currentLikelihood;
bestObm = obm;
}
}
System.out.println(bestObm);
// // Grow network from this operation and object model and write to csv
// Network grownNet = new DirectedNetwork(reader, true);
// Grow grow = new Grow(grownNet, bestObm, om, 10L, 100L);
// grow.execute();
// WriteNet writer = new WriteNetCSV(grownNet, " ", "testOutput.csv");
// writer.write();
}
}
10 changes: 9 additions & 1 deletion feta/actions/FitMixedModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public class FitMixedModel extends SimpleAction {
public FullObjectModel objectModel_;
public int granularity_;
public List<int[]> configs_;
private double bestLikelihood_;
public long startTime_=10;
private boolean orderedData_ = false;
private Random random_;
Expand Down Expand Up @@ -63,7 +64,7 @@ private static List<int[]> generatePartitions(int n, int k) {
return parts;
}
List<int[]> newParts = new ArrayList<>();
for (int l = 0; l < n; l++) {
for (int l = 0; l <= n; l++) {
List<int[]> oldParts = generatePartitions(n-l,k-1);
for (int[] partition: oldParts) {
int[] newPartition = new int[partition.length+1];
Expand All @@ -89,6 +90,7 @@ private ArrayList<double[]> generateModels() {

public void execute(){
ParseNet parser;
network_.buildUpTo(startTime_);
operationsExtracted_= new ArrayList<>();
if (network_ instanceof UndirectedNetwork) {
parser = new ParseNetUndirected((UndirectedNetwork) network_);
Expand Down Expand Up @@ -168,6 +170,8 @@ public String getLikelihoods(ParseNet parser, long start, long end, int[] totalC
String toPrint = "{\"start\":"+start+", \"c0max\" : "+maxLike+
", \"raw\": "+bestRaw+", \"choices\": "+noChoices+", \"models\": ";

bestLikelihood_=maxLike;

String[] models = new String[bestConfig.length];
for (int i = 0; i < bestConfig.length; i++) {
models[i]="{\""+obm.components_.get(i)+"\": "+bestConfig[i]+"}";
Expand All @@ -185,6 +189,10 @@ public FullObjectModel getFittedModel() {
return objectModel_;
}

public double getBestLikelihood() {
return bestLikelihood_;
}

public void updateLikelihoods(Operation op, MixedModel obm) {
op.setRandom(random_);
op.setNodeChoices(orderedData_);
Expand Down
8 changes: 8 additions & 0 deletions feta/objectmodels/FullObjectModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,12 @@ public void parseObjectModels(JSONArray model) {
}
checkValid();
}

@Override
public String toString() {
return "FullObjectModel{" +
"objectModels_=" + objectModels_ +
", timeToOM_=" + timeToOM_ +
'}';
}
}
15 changes: 12 additions & 3 deletions feta/objectmodels/MixedModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,27 @@ public class MixedModel {
public ArrayList<ObjectModelComponent> components_;
private double[] weights_;
private boolean checkWeights_;
private int counter;
private HashMap<double[], Double> likelihoods_;

// "build from scratch" constructor
public MixedModel() {components_=new ArrayList<ObjectModelComponent>();}
public MixedModel() {
components_=new ArrayList<ObjectModelComponent>();
counter = 0;
}

// Constructor for FitMixedModel
public MixedModel(ArrayList<ObjectModelComponent> components) {
components_ = components;
counter = 0;
}

// Constructor for Grow/Likelihood
public MixedModel(ArrayList<ObjectModelComponent> components, double[] weights) {
components_ = components;
weights_ = weights;
checkValid();
counter = 0;
}

public HashMap<double[], Double> getLikelihoods () {
Expand Down Expand Up @@ -101,7 +107,8 @@ public final int nodeDrawWithoutReplacement(Network net, HashSet<Integer> availa
calcNormalisation(net, availableNodes);
} else {
updateNormalisation(net, availableNodes, seedNode);
checkUpdatedNorm(net,availableNodes);
if (counter < 50)
checkUpdatedNorm(net,availableNodes);
}
double r = Math.random();
double weightSoFar = 0.0;
Expand All @@ -116,7 +123,8 @@ public final int nodeDrawWithoutReplacement(Network net, HashSet<Integer> availa
}

public int[] drawMultipleNodesWithoutReplacement(Network net, int sampleSize, HashSet<Integer> availableNodes) {
checkNorm(net);
if (counter < 50)
checkNorm(net);
int[] chosenNodes = new int[sampleSize];
if (sampleSize == 0)
return chosenNodes;
Expand All @@ -131,6 +139,7 @@ public int[] drawMultipleNodesWithoutReplacement(Network net, int sampleSize, Ha
chosenNodes[i] = node;
seedNode = node;
}
counter+=1;
return chosenNodes;
}

Expand Down
8 changes: 4 additions & 4 deletions feta/objectmodels/components/DegreeModelComponent.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ public void updateNormalisation(DirectedNetwork net, HashSet<Integer> availableN
}

public double calcProbability(UndirectedNetwork net, int node) {
if (tempConstant_==0.0){
return 0.0;
if (random_){
return 1.0/tempConstant_;
}
return net.getDegree(node)/tempConstant_;
}

public double calcProbability(DirectedNetwork net, int node) {
if (normalisationConstant_==0.0){
return 0.0;
if (random_){
return 1.0/tempConstant_;
}
if (useInDegree_) {
return (net.getInDegree(node))/tempConstant_;
Expand Down
23 changes: 17 additions & 6 deletions feta/objectmodels/components/DegreePower.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,21 @@

public class DegreePower extends ObjectModelComponent {

public DegreePower(){};
public DegreePower(double power) {
power_=power;
}
public DegreePower(double power, boolean useInDegree) {
power_=power;
useInDegree_=useInDegree;
}

public double power_=1.0;
public boolean useInDegree_=true;

@Override
public void calcNormalisation(UndirectedNetwork net, int sourceNode, HashSet<Integer> availableNodes) {
random_=false;
double degSum = 0.0;
for (int node: availableNodes) {
degSum += Math.pow(net.getDegree(node), power_);
Expand All @@ -30,6 +40,7 @@ public void calcNormalisation(UndirectedNetwork net, int sourceNode, HashSet<Int

@Override
public void calcNormalisation(DirectedNetwork net, int sourceNode, HashSet<Integer> availableNodes) {
random_=false;
double degSum = 0.0;
for (int node: availableNodes) {
if (useInDegree_) {
Expand Down Expand Up @@ -61,17 +72,17 @@ public void updateNormalisation(UndirectedNetwork net, HashSet<Integer> availabl
}

public double calcProbability(UndirectedNetwork net, int node) {
if (tempConstant_==0.0)
return 0.0;
if (random_)
return 1.0/tempConstant_; // tempConstant is never zero if random is true
return Math.pow(net.getDegree(node), power_)/tempConstant_;
}

public double calcProbability(DirectedNetwork net, int node) {
if (tempConstant_==0.0)
return 0.0;
if (random_)
return 1.0/tempConstant_; // tempConstant is never zero if random is true
if (useInDegree_)
return Math.pow(net.getInDegree(node)+1,power_)/tempConstant_;
return Math.pow(net.getOutDegree(node)+1,power_)/tempConstant_;
return Math.pow(net.getInDegree(node),power_)/tempConstant_;
return Math.pow(net.getOutDegree(node),power_)/tempConstant_;
}

public void parseJSON(JSONObject params) {
Expand Down
3 changes: 1 addition & 2 deletions feta/operations/Operation.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public abstract class Operation {
ArrayList<int[]> nodeChoices_;
ArrayList <int[]> nodeOrders_;
private long time_;
private int noChoices_=0;
protected int noChoices_=0;
private Random generator_;
private boolean censored_= false;

Expand Down Expand Up @@ -106,7 +106,6 @@ ArrayList<int[]> generateOrdersFromOperation() {
}

public void filterNodeChoices() {
noChoices_=0;
ArrayList<int[]> newChoices = new ArrayList<int[]>();
for (int[] nodeSet: nodeChoices_) {
int[] copy = Methods.removeNegativeNumbers(nodeSet);
Expand Down
6 changes: 3 additions & 3 deletions feta/operations/Star.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ public void chooseNodes(Network net, MixedModel obm) throws Exception {

public void setNodeChoices(boolean orderedData) {
nodeChoices_= new ArrayList<int[]>();
// if (internal_) {
// nodeChoices_.add(new int[] {centreNode_});
// }
if (internal_) {
noChoices_+=1;
}
if (orderedData) {
for (int node: leafNodes_) {
nodeChoices_.add(new int[] {node});
Expand Down
7 changes: 5 additions & 2 deletions feta/parsenet/ParseNet.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ public abstract class ParseNet {
protected HashSet<String> processedNodes_;

public void parseNetwork(long start, long end) {
processedNodes_=new HashSet<>();

/* Get string set of already processed nodes */
net_.buildUpTo(start);
for (int node: net_.getNodeListCopy()) {
Expand All @@ -43,6 +41,11 @@ public void parseNetwork(long start, long end) {
}
}

protected void initialiseProcessedNodeSet() {
for (int i = 0; i < net_.noNodes_; i++) {
processedNodes_.add(net_.nodeNoToName(i));
}
}

public ArrayList<Link> getNextLinkSet(ArrayList<Link> links) {
ArrayList<Link> linkSet = new ArrayList<Link>();
Expand Down
1 change: 1 addition & 0 deletions feta/parsenet/ParseNetDirected.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public ParseNetDirected(DirectedNetwork net){
operations_ = new ArrayList<Operation>();
processedNodes_= new HashSet<>();
net_=net;
initialiseProcessedNodeSet();
}

public ArrayList<Operation> parseNewLinks(ArrayList <Link> links, Network net) {
Expand Down
1 change: 1 addition & 0 deletions feta/parsenet/ParseNetUndirected.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public ParseNetUndirected(UndirectedNetwork network) {
operations_= new ArrayList<Operation>();
processedNodes_= new HashSet<>();
net_=network;
initialiseProcessedNodeSet();
}

public ArrayList<Operation> parseNewLinks(ArrayList <Link> links, Network net) {
Expand Down

0 comments on commit 0115f37

Please sign in to comment.