From 0115f37a031df9f9ec605bdefa3e29b560b2497d Mon Sep 17 00:00:00 2001 From: narnolddd Date: Tue, 18 Apr 2023 11:27:52 +0100 Subject: [PATCH] Make sure degree model components working with directed networks and that centre node choices are acknowledged in the case of internal stars --- feta/FitAndCloneRunner.java | 57 +++++++++++-------- feta/actions/FitMixedModel.java | 10 +++- feta/objectmodels/FullObjectModel.java | 8 +++ feta/objectmodels/MixedModel.java | 15 ++++- .../components/DegreeModelComponent.java | 8 +-- feta/objectmodels/components/DegreePower.java | 23 ++++++-- feta/operations/Operation.java | 3 +- feta/operations/Star.java | 6 +- feta/parsenet/ParseNet.java | 7 ++- feta/parsenet/ParseNetDirected.java | 1 + feta/parsenet/ParseNetUndirected.java | 1 + 11 files changed, 94 insertions(+), 45 deletions(-) diff --git a/feta/FitAndCloneRunner.java b/feta/FitAndCloneRunner.java index a297551..8de73cc 100644 --- a/feta/FitAndCloneRunner.java +++ b/feta/FitAndCloneRunner.java @@ -8,9 +8,7 @@ import feta.network.UndirectedNetwork; import feta.objectmodels.FullObjectModel; import feta.objectmodels.MixedModel; -import feta.objectmodels.components.DegreeModelComponent; -import feta.objectmodels.components.ObjectModelComponent; -import feta.objectmodels.components.RandomAttachment; +import feta.objectmodels.components.*; import feta.operations.MixedOperations; import feta.operations.OperationModel; import feta.readnet.ReadNet; @@ -25,31 +23,42 @@ public class FitAndCloneRunner { public static void main(String[] args) { // Read in network to be fitted ReadNet reader = new ReadNetCSV("typeTest.dat"," ",true,0,1,2,3,4); - Network net = new DirectedNetwork(reader, true); + double[] degreePowerParms = new double[] {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.2}; - // Specify object model to be tested - ArrayList components = new ArrayList<>() { - { - add( new DegreeModelComponent()); - add( new RandomAttachment()); - } - }; + double bestLikelihood = 0.0; + FullObjectModel bestObm = null; + for (double d: degreePowerParms) { + Network net = new DirectedNetwork(reader, true); - MixedModel model = new MixedModel(components); + // Specify object model to be tested + ArrayList components = new ArrayList<>() { + { + add(new RandomAttachment()); + add(new DegreeModelComponent()); + } + }; - // Fit mixed model - FitMixedModel fit = new FitMixedModel(net,model,100,10,false); - fit.execute(); + MixedModel model = new MixedModel(components); - // Get out parsed operation model and fitted object model - OperationModel om = new MixedOperations(fit.getParsedOperations()); - FullObjectModel obm = fit.getFittedModel(); + // Fit mixed model + FitMixedModel fit = new FitMixedModel(net, model, 100, 10, false); + fit.execute(); - // Grow network from this operation and object model and write to csv - Network grownNet = new DirectedNetwork(reader,true); - Grow grow = new Grow(grownNet,obm,om,10L,100L); - grow.execute(); - WriteNet writer = new WriteNetCSV(grownNet, " ", "testOutput.csv"); - writer.write(); + // Get out parsed operation model and fitted object model + OperationModel om = new MixedOperations(fit.getParsedOperations()); + FullObjectModel obm = fit.getFittedModel(); + double currentLikelihood = fit.getBestLikelihood(); + if (currentLikelihood > bestLikelihood) { + bestLikelihood = currentLikelihood; + bestObm = obm; + } + } + System.out.println(bestObm); +// // Grow network from this operation and object model and write to csv +// Network grownNet = new DirectedNetwork(reader, true); +// Grow grow = new Grow(grownNet, bestObm, om, 10L, 100L); +// grow.execute(); +// WriteNet writer = new WriteNetCSV(grownNet, " ", "testOutput.csv"); +// writer.write(); } } diff --git a/feta/actions/FitMixedModel.java b/feta/actions/FitMixedModel.java index fef3ce5..3a0f079 100644 --- a/feta/actions/FitMixedModel.java +++ b/feta/actions/FitMixedModel.java @@ -28,6 +28,7 @@ public class FitMixedModel extends SimpleAction { public FullObjectModel objectModel_; public int granularity_; public List configs_; + private double bestLikelihood_; public long startTime_=10; private boolean orderedData_ = false; private Random random_; @@ -63,7 +64,7 @@ private static List generatePartitions(int n, int k) { return parts; } List newParts = new ArrayList<>(); - for (int l = 0; l < n; l++) { + for (int l = 0; l <= n; l++) { List oldParts = generatePartitions(n-l,k-1); for (int[] partition: oldParts) { int[] newPartition = new int[partition.length+1]; @@ -89,6 +90,7 @@ private ArrayList generateModels() { public void execute(){ ParseNet parser; + network_.buildUpTo(startTime_); operationsExtracted_= new ArrayList<>(); if (network_ instanceof UndirectedNetwork) { parser = new ParseNetUndirected((UndirectedNetwork) network_); @@ -168,6 +170,8 @@ public String getLikelihoods(ParseNet parser, long start, long end, int[] totalC String toPrint = "{\"start\":"+start+", \"c0max\" : "+maxLike+ ", \"raw\": "+bestRaw+", \"choices\": "+noChoices+", \"models\": "; + bestLikelihood_=maxLike; + String[] models = new String[bestConfig.length]; for (int i = 0; i < bestConfig.length; i++) { models[i]="{\""+obm.components_.get(i)+"\": "+bestConfig[i]+"}"; @@ -185,6 +189,10 @@ public FullObjectModel getFittedModel() { return objectModel_; } + public double getBestLikelihood() { + return bestLikelihood_; + } + public void updateLikelihoods(Operation op, MixedModel obm) { op.setRandom(random_); op.setNodeChoices(orderedData_); diff --git a/feta/objectmodels/FullObjectModel.java b/feta/objectmodels/FullObjectModel.java index 0e757d5..8873dd8 100644 --- a/feta/objectmodels/FullObjectModel.java +++ b/feta/objectmodels/FullObjectModel.java @@ -102,4 +102,12 @@ public void parseObjectModels(JSONArray model) { } checkValid(); } + + @Override + public String toString() { + return "FullObjectModel{" + + "objectModels_=" + objectModels_ + + ", timeToOM_=" + timeToOM_ + + '}'; + } } \ No newline at end of file diff --git a/feta/objectmodels/MixedModel.java b/feta/objectmodels/MixedModel.java index 99790ac..7927e5d 100644 --- a/feta/objectmodels/MixedModel.java +++ b/feta/objectmodels/MixedModel.java @@ -19,14 +19,19 @@ public class MixedModel { public ArrayList components_; private double[] weights_; private boolean checkWeights_; + private int counter; private HashMap likelihoods_; // "build from scratch" constructor - public MixedModel() {components_=new ArrayList();} + public MixedModel() { + components_=new ArrayList(); + counter = 0; + } // Constructor for FitMixedModel public MixedModel(ArrayList components) { components_ = components; + counter = 0; } // Constructor for Grow/Likelihood @@ -34,6 +39,7 @@ public MixedModel(ArrayList components, double[] weights) components_ = components; weights_ = weights; checkValid(); + counter = 0; } public HashMap getLikelihoods () { @@ -101,7 +107,8 @@ public final int nodeDrawWithoutReplacement(Network net, HashSet availa calcNormalisation(net, availableNodes); } else { updateNormalisation(net, availableNodes, seedNode); - checkUpdatedNorm(net,availableNodes); + if (counter < 50) + checkUpdatedNorm(net,availableNodes); } double r = Math.random(); double weightSoFar = 0.0; @@ -116,7 +123,8 @@ public final int nodeDrawWithoutReplacement(Network net, HashSet availa } public int[] drawMultipleNodesWithoutReplacement(Network net, int sampleSize, HashSet availableNodes) { - checkNorm(net); + if (counter < 50) + checkNorm(net); int[] chosenNodes = new int[sampleSize]; if (sampleSize == 0) return chosenNodes; @@ -131,6 +139,7 @@ public int[] drawMultipleNodesWithoutReplacement(Network net, int sampleSize, Ha chosenNodes[i] = node; seedNode = node; } + counter+=1; return chosenNodes; } diff --git a/feta/objectmodels/components/DegreeModelComponent.java b/feta/objectmodels/components/DegreeModelComponent.java index d771d1e..be50e7c 100644 --- a/feta/objectmodels/components/DegreeModelComponent.java +++ b/feta/objectmodels/components/DegreeModelComponent.java @@ -74,15 +74,15 @@ public void updateNormalisation(DirectedNetwork net, HashSet availableN } public double calcProbability(UndirectedNetwork net, int node) { - if (tempConstant_==0.0){ - return 0.0; + if (random_){ + return 1.0/tempConstant_; } return net.getDegree(node)/tempConstant_; } public double calcProbability(DirectedNetwork net, int node) { - if (normalisationConstant_==0.0){ - return 0.0; + if (random_){ + return 1.0/tempConstant_; } if (useInDegree_) { return (net.getInDegree(node))/tempConstant_; diff --git a/feta/objectmodels/components/DegreePower.java b/feta/objectmodels/components/DegreePower.java index 48622b9..badb159 100644 --- a/feta/objectmodels/components/DegreePower.java +++ b/feta/objectmodels/components/DegreePower.java @@ -8,11 +8,21 @@ public class DegreePower extends ObjectModelComponent { + public DegreePower(){}; + public DegreePower(double power) { + power_=power; + } + public DegreePower(double power, boolean useInDegree) { + power_=power; + useInDegree_=useInDegree; + } + public double power_=1.0; public boolean useInDegree_=true; @Override public void calcNormalisation(UndirectedNetwork net, int sourceNode, HashSet availableNodes) { + random_=false; double degSum = 0.0; for (int node: availableNodes) { degSum += Math.pow(net.getDegree(node), power_); @@ -30,6 +40,7 @@ public void calcNormalisation(UndirectedNetwork net, int sourceNode, HashSet availableNodes) { + random_=false; double degSum = 0.0; for (int node: availableNodes) { if (useInDegree_) { @@ -61,17 +72,17 @@ public void updateNormalisation(UndirectedNetwork net, HashSet availabl } public double calcProbability(UndirectedNetwork net, int node) { - if (tempConstant_==0.0) - return 0.0; + if (random_) + return 1.0/tempConstant_; // tempConstant is never zero if random is true return Math.pow(net.getDegree(node), power_)/tempConstant_; } public double calcProbability(DirectedNetwork net, int node) { - if (tempConstant_==0.0) - return 0.0; + if (random_) + return 1.0/tempConstant_; // tempConstant is never zero if random is true if (useInDegree_) - return Math.pow(net.getInDegree(node)+1,power_)/tempConstant_; - return Math.pow(net.getOutDegree(node)+1,power_)/tempConstant_; + return Math.pow(net.getInDegree(node),power_)/tempConstant_; + return Math.pow(net.getOutDegree(node),power_)/tempConstant_; } public void parseJSON(JSONObject params) { diff --git a/feta/operations/Operation.java b/feta/operations/Operation.java index 978d7d6..851a5ca 100644 --- a/feta/operations/Operation.java +++ b/feta/operations/Operation.java @@ -15,7 +15,7 @@ public abstract class Operation { ArrayList nodeChoices_; ArrayList nodeOrders_; private long time_; - private int noChoices_=0; + protected int noChoices_=0; private Random generator_; private boolean censored_= false; @@ -106,7 +106,6 @@ ArrayList generateOrdersFromOperation() { } public void filterNodeChoices() { - noChoices_=0; ArrayList newChoices = new ArrayList(); for (int[] nodeSet: nodeChoices_) { int[] copy = Methods.removeNegativeNumbers(nodeSet); diff --git a/feta/operations/Star.java b/feta/operations/Star.java index 373b084..83f6d58 100644 --- a/feta/operations/Star.java +++ b/feta/operations/Star.java @@ -70,9 +70,9 @@ public void chooseNodes(Network net, MixedModel obm) throws Exception { public void setNodeChoices(boolean orderedData) { nodeChoices_= new ArrayList(); -// if (internal_) { -// nodeChoices_.add(new int[] {centreNode_}); -// } + if (internal_) { + noChoices_+=1; + } if (orderedData) { for (int node: leafNodes_) { nodeChoices_.add(new int[] {node}); diff --git a/feta/parsenet/ParseNet.java b/feta/parsenet/ParseNet.java index 7a76932..276d3d5 100644 --- a/feta/parsenet/ParseNet.java +++ b/feta/parsenet/ParseNet.java @@ -22,8 +22,6 @@ public abstract class ParseNet { protected HashSet processedNodes_; public void parseNetwork(long start, long end) { - processedNodes_=new HashSet<>(); - /* Get string set of already processed nodes */ net_.buildUpTo(start); for (int node: net_.getNodeListCopy()) { @@ -43,6 +41,11 @@ public void parseNetwork(long start, long end) { } } + protected void initialiseProcessedNodeSet() { + for (int i = 0; i < net_.noNodes_; i++) { + processedNodes_.add(net_.nodeNoToName(i)); + } + } public ArrayList getNextLinkSet(ArrayList links) { ArrayList linkSet = new ArrayList(); diff --git a/feta/parsenet/ParseNetDirected.java b/feta/parsenet/ParseNetDirected.java index 6b45562..d3fc51a 100644 --- a/feta/parsenet/ParseNetDirected.java +++ b/feta/parsenet/ParseNetDirected.java @@ -18,6 +18,7 @@ public ParseNetDirected(DirectedNetwork net){ operations_ = new ArrayList(); processedNodes_= new HashSet<>(); net_=net; + initialiseProcessedNodeSet(); } public ArrayList parseNewLinks(ArrayList links, Network net) { diff --git a/feta/parsenet/ParseNetUndirected.java b/feta/parsenet/ParseNetUndirected.java index 7cdc6b8..b12b238 100644 --- a/feta/parsenet/ParseNetUndirected.java +++ b/feta/parsenet/ParseNetUndirected.java @@ -20,6 +20,7 @@ public ParseNetUndirected(UndirectedNetwork network) { operations_= new ArrayList(); processedNodes_= new HashSet<>(); net_=network; + initialiseProcessedNodeSet(); } public ArrayList parseNewLinks(ArrayList links, Network net) {