forked from stanford-ppl/spatial
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add congestion model using lattice regression, apps for collecting da…
…ta, training a hypercube, inferring in scala, and some refactoring
- Loading branch information
Showing
20 changed files
with
710 additions
and
135 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
package models | ||
|
||
import java.io.File | ||
import java.io.PrintWriter | ||
import utils.io.files._ | ||
import utils.math.{CombinationTree, ReduceTree} | ||
|
||
import scala.io.Source | ||
|
||
object CongestionModel { | ||
|
||
abstract class FeatureVec[T] { | ||
def loads: T | ||
def stores: T | ||
def gateds: T | ||
def outerIters: T | ||
def innerIters: T | ||
def toSeq: Seq[T] = Seq(stores, outerIters, loads, innerIters, gateds) | ||
} | ||
case class RawFeatureVec(loads: Double, stores: Double, gateds: Double, outerIters: Double, innerIters: Double) extends FeatureVec[Double] | ||
case class CalibFeatureVec(loads: Double, stores: Double, gateds: Double, outerIters: Double, innerIters: Double) extends FeatureVec[Double] | ||
|
||
// Set up lattice properties | ||
val feature_dims = 5 | ||
val lattice_rank = 5 | ||
val lattice_size = Seq(3,3,3,3,3) | ||
val num_keypoints = 8 | ||
val num_lattices = 1 | ||
var model: String = "" | ||
|
||
// Derive lattice properties | ||
val sizes = scala.Array.tabulate(lattice_rank){i => lattice_size(i)} | ||
val dimensions = sizes.length | ||
val params_per_lattice = sizes.product | ||
val strides: scala.Array[Int] = scala.Array.fill(dimensions){1} | ||
val nparams = num_lattices * params_per_lattice | ||
|
||
// Grab lattice params | ||
lazy val loads_keypoints_inputs = ModelData.loads_keypoints_inputs(model).map(_.toDouble) //loadCSVNow[Int](s"../data/${model}/CALIBRATOR_INPUT_PARAMS/loads_keypoints_inputs.csv", ","){x => x.toDouble} | ||
lazy val loads_keypoints_outputs = ModelData.loads_keypoints_outputs(model).map(_.toDouble) //loadCSVNow[Double](s"../data/${model}/CALIBRATOR_OUTPUT_PARAMS/loads_keypoints_outputs.csv", ","){x => x.toDouble} | ||
lazy val stores_keypoints_inputs = ModelData.stores_keypoints_inputs(model).map(_.toDouble) //loadCSVNow[Int](s"../data/${model}/CALIBRATOR_INPUT_PARAMS/stores_keypoints_inputs.csv", ","){x => x.toDouble} | ||
lazy val stores_keypoints_outputs = ModelData.stores_keypoints_outputs(model).map(_.toDouble) //loadCSVNow[Double](s"../data/${model}/CALIBRATOR_OUTPUT_PARAMS/stores_keypoints_outputs.csv", ","){x => x.toDouble} | ||
lazy val gateds_keypoints_inputs = ModelData.gateds_keypoints_inputs(model).map(_.toDouble) //loadCSVNow[Int](s"../data/${model}/CALIBRATOR_INPUT_PARAMS/gateds_keypoints_inputs.csv", ","){x => x.toDouble} | ||
lazy val gateds_keypoints_outputs = ModelData.gateds_keypoints_outputs(model).map(_.toDouble) //loadCSVNow[Double](s"../data/${model}/CALIBRATOR_OUTPUT_PARAMS/gateds_keypoints_outputs.csv", ","){x => x.toDouble} | ||
lazy val outerIters_keypoints_inputs = ModelData.outerIters_keypoints_inputs(model).map(_.toDouble) //loadCSVNow[Int](s"../data/${model}/CALIBRATOR_INPUT_PARAMS/outerIters_keypoints_inputs.csv", ","){x => x.toDouble} | ||
lazy val outerIters_keypoints_outputs = ModelData.outerIters_keypoints_outputs(model).map(_.toDouble) //loadCSVNow[Double](s"../data/${model}/CALIBRATOR_OUTPUT_PARAMS/outerIters_keypoints_outputs.csv", ","){x => x.toDouble} | ||
lazy val innerIters_keypoints_inputs = ModelData.innerIters_keypoints_inputs(model).map(_.toDouble) //loadCSVNow[Int](s"../data/${model}/CALIBRATOR_INPUT_PARAMS/innerIters_keypoints_inputs.csv", ","){x => x.toDouble} | ||
lazy val innerIters_keypoints_outputs = ModelData.innerIters_keypoints_outputs(model).map(_.toDouble) //loadCSVNow[Double](s"../data/${model}/CALIBRATOR_OUTPUT_PARAMS/innerIters_keypoints_outputs.csv", ","){x => x.toDouble} | ||
lazy val params = ModelData.params(model).map(_.toDouble) //loadCSVNow[Double](s"../data/${model}/LATTICE_PARAMS.csv", ","){x => x.toDouble} | ||
|
||
/** Calibrate one element in a feature */ | ||
def calibrate(inputs: Seq[Double], outputs: Seq[Double], feature: Double, max_dim: Int): Double = { | ||
val pwl = if (inputs.nonEmpty) { | ||
if (feature < inputs.head) outputs.head | ||
else if (feature >= inputs.last) outputs.last | ||
else (0 until inputs.size-1).collect{case i if (inputs(i) <= feature && feature < inputs(i+1)) => | ||
if (feature == inputs(i) || outputs(i) == outputs(i+1)) outputs(i) | ||
else outputs(i) + (feature - inputs(i)) * ((outputs(i+1) - outputs(i)) / (inputs(i+1) - inputs(i))) | ||
}.headOption.getOrElse(outputs.last) | ||
} | ||
else 0.0 | ||
(0.0 max pwl) min {max_dim-1}.toDouble | ||
|
||
} | ||
|
||
/** Run raw features through calibrators */ | ||
def calibrate_features(features: RawFeatureVec): CalibFeatureVec = { | ||
val loads = calibrate(loads_keypoints_inputs, loads_keypoints_outputs, features.loads, lattice_size(0)) | ||
val stores = calibrate(stores_keypoints_inputs, stores_keypoints_outputs, features.stores, lattice_size(1)) | ||
val gateds = calibrate(gateds_keypoints_inputs, gateds_keypoints_outputs, features.gateds, lattice_size(2)) | ||
val outerIters = calibrate(outerIters_keypoints_inputs, outerIters_keypoints_outputs, features.outerIters, lattice_size(3)) | ||
val innerIters = calibrate(innerIters_keypoints_inputs, innerIters_keypoints_outputs, features.innerIters, lattice_size(4)) | ||
val calib = CalibFeatureVec(loads = loads, stores = stores, gateds = gateds, outerIters = outerIters, innerIters = innerIters) | ||
calib | ||
} | ||
|
||
/** Get all corners in the hypercube */ | ||
def allCorners(maxes: Seq[scala.Int], partials: Seq[Seq[scala.Int]] = Seq(Seq.empty)): Seq[Seq[scala.Int]] = maxes match { | ||
case Nil => Nil | ||
case h::tail if tail.nonEmpty => (0 to h).flatMap{i => allCorners(tail, partials.map(_ ++ Seq(i)))} | ||
case h::tail if tail.isEmpty => (0 to h).flatMap{i => partials.map(_ ++ Seq(i))} | ||
} | ||
|
||
/** Run calibrated features through hypercube interp */ | ||
def hypercube_features(features: CalibFeatureVec): Double = { | ||
val residualPairs: Seq[Seq[Double]] = Seq.tabulate(dimensions) {i => | ||
val x = features.toSeq(i) | ||
Seq(x % 1.0, 1.0-(x%1.0)) | ||
} | ||
|
||
// Compute all hypervolumes in binary counting order (000, 001, 010, 011, etc..) | ||
val hypervolumes: Seq[Double] = CombinationTree[Double](residualPairs:_*)(_*_) | ||
// Compute hypercube origin | ||
val base: Seq[Int] = Array.tabulate(dimensions) {x => features.toSeq(x).toInt} | ||
// Get all vertices of a hypercube and reverse so that these are opposite the hypervolumes | ||
val corners: Seq[Seq[scala.Int]] = allCorners(Seq.fill(dimensions)(1)).reverse | ||
|
||
// Get flat index for each (corner + origin) | ||
val indices: Seq[Int] = corners map { c => | ||
val corner = (base zip c.map(_.toInt)) map {case (a,b) => a + b} | ||
corner.zipWithIndex.map { case (cc, i) => | ||
cc * lattice_size.drop(i+1).product | ||
} reduce {_+_} | ||
} | ||
|
||
// Get weighted sum | ||
val x = hypervolumes.map(_.toDouble).zip(indices).collect{case (hv,i) if hv > 0 => hv * params(i)}.sum | ||
x | ||
} | ||
|
||
/** Evaluate model on features (loads: Int, stores: Int, gateds: Int, outerIters: Int, innerIters: Int) */ | ||
def evaluate(features: RawFeatureVec, typ: Runtime.CtrlSchedule): Int = { | ||
model = typ.toString | ||
|
||
val calibrated_features = calibrate_features(features) | ||
val result = hypercube_features(calibrated_features) | ||
result.toInt | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package models | ||
|
||
object DenseLoadModelData{ | ||
val params = Seq(1607.216,648.4963,368.1771,2122.0925,962.271,645.5907,2779.559,1537.6995,1151.2914,1952.1106,835.7142,477.6142,2748.6926,1396.6589,1001.90424,3878.6516,2779.836,2351.3564,2369.9692,1175.3682,768.8471,3480.4053,2101.2588,1653.4564,5080.296,4317.6562,4028.652,2168.0515,1119.1941,813.2784,3081.2996,2046.3541,1789.7451,4216.033,3374.7634,3254.7761,2767.3926,1598.2214,1507.8003,4359.934,3766.821,4568.9688,6529.0073,7739.8438,9023.787,3434.2393,2286.3079,2031.8228,5710.4824,5762.401,7332.706,8953.163,12407.342,14173.723,2715.7,1580.7279,1201.4835,4044.112,3116.2593,2813.7883,5683.8906,5301.641,5262.5186,3592.2583,2418.7534,2497.2637,6044.716,6555.872,7962.523,9255.937,12899.41,15427.333,4551.948,3530.7036,3333.6372,8086.4277,9948.789,13019.589,12889.885,20372.781,24641.908,1174.5428,234.08334,274.31076,1609.1492,337.94016,567.33105,2230.7122,577.98737,838.2734,1460.9763,311.3316,363.87692,2188.8416,533.4591,662.51843,3484.0576,1916.9586,1818.1388,1874.7726,523.9037,564.71954,2965.5408,880.47864,882.2886,4906.0356,3193.3884,3408.1394,1741.262,463.4396,595.77625,2689.5051,1321.0378,1396.2699,3907.4185,1909.868,2397.1787,2304.4019,362.26767,1226.2435,4187.6655,1305.0627,4309.4517,7194.558,7884.9507,9974.382,3009.305,1083.6733,1113.5845,5840.6875,2930.077,7385.636,10850.563,16913.354,16858.924,2247.3967,791.38,858.7941,3719.193,1948.8992,2226.4827,5629.123,3452.2043,3874.473,3118.5774,469.60852,2149.3792,6385.8306,4279.0767,8109.4897,11178.077,15298.498,17860.398,4196.5537,1890.1986,1851.9263,9041.918,6623.696,13867.92,16975.623,29861.576,31802.404,831.4097,259.87247,267.80084,1248.092,555.44366,567.14636,1860.8497,798.1045,905.307,1086.6595,346.6293,344.99188,1854.3694,690.3772,784.9425,3144.214,1757.7875,1752.2285,1498.828,561.6,575.25995,2666.876,946.3104,971.49084,4734.354,3368.1318,3339.9658,1431.4207,524.507,565.54034,2471.8838,1299.0255,1337.3829,3904.1943,2384.8994,2473.7197,2204.7664,1252.9379,1274.327,4855.465,4764.377,4835.339,8253.649,10080.159,10024.748,2903.5566,1147.0736,1177.5868,7089.2153,7457.9,7662.5674,12129.67,15735.41,15848.641,1962.3672,809.2306,860.36786,3638.2449,2047.3785,2044.6423,5899.906,4045.6597,4138.4985,3236.4773,2058.7397,2054.691,7665.2607,7994.1255,8072.894,13304.945,17878.521,17747.057,4355.019,2054.842,2057.5217,11745.724,14810.251,14907.296,20321.078,30924.383,30947.24) | ||
val innerIters_keypoints_inputs = Seq(32.0,96.0,160.0,224.0) | ||
val stores_keypoints_inputs = Seq(0.0,3.0,6.0,9.0) | ||
val gateds_keypoints_inputs = Seq(0.0,2.0,4.0) | ||
val outerIters_keypoints_inputs = Seq(2.0,8.0,14.0) | ||
val loads_keypoints_inputs = Seq(1.0,3.0,5.0,7.0,9.0) | ||
val outerIters_keypoints_outputs = Seq(0.0,0.96400726,2.0) | ||
val gateds_keypoints_outputs = Seq(1.0631267,2.0,1.7433352) | ||
val stores_keypoints_outputs = Seq(1.3180362,1.8577967,2.0,2.0) | ||
val innerIters_keypoints_outputs = Seq(0.0,0.92205065,1.4818518,2.0) | ||
val loads_keypoints_outputs = Seq(0.0,0.48250738,1.0572726,1.5440177,2.0) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package models | ||
|
||
object DenseStoreModelData{ | ||
val params = Seq(1719.094,1443.751,1188.5411,1953.5675,1683.4391,1444.2003,2238.9692,2014.7754,1816.3032,1861.0571,1582.0151,1334.6908,2179.2366,1936.6375,1721.997,2601.6255,2524.0051,2489.403,2032.4385,1779.5958,1559.1464,2444.615,2282.862,2164.743,3007.1777,3136.053,3387.8828,1884.6885,1617.8473,1413.9296,2207.1057,1983.6978,1847.1709,2608.8718,2502.5862,2475.0615,2055.1602,1766.9823,1582.3021,2501.278,2301.7852,2257.1072,3133.297,3341.179,3830.592,2277.6794,2047.0331,1943.1749,2868.0767,2846.0483,3154.5967,3738.446,4455.593,5777.8486,2079.1929,1846.3665,1755.442,2515.4668,2409.4521,2545.2043,3055.5276,3175.4692,3533.2097,2296.1091,2057.5986,2115.9092,2919.22,2951.5002,3575.5637,3786.52,4518.213,5991.8174,2578.2444,2433.2158,2575.8533,3387.9102,3662.4443,4704.9966,4560.117,5958.072,8487.428,1452.7153,871.69824,306.48032,1693.6179,1009.9475,408.6407,2016.089,1322.6473,657.40436,1598.5544,951.605,374.29004,1955.9524,1218.3593,427.0534,2527.0808,2148.527,1649.4894,1798.9082,1174.1864,574.18054,2297.7712,1672.3607,883.09607,3115.0178,3159.4062,3373.2473,1629.7843,985.74646,493.8577,2002.1816,1300.5197,811.5159,2515.6304,1917.2313,1320.7115,1798.4254,892.2304,244.27414,2346.29,1132.8811,45.666004,3355.0208,3463.0225,4486.307,2084.902,1282.5164,740.1523,2880.649,1998.1063,2390.8884,4406.396,6394.4463,11387.591,1853.5243,1158.9678,857.3309,2422.4895,1848.7975,2152.9004,3191.451,3034.3762,3477.758,2087.3582,993.147,1148.0575,2995.9531,2168.4907,4465.119,4532.793,6423.2305,12048.127,2496.3594,1677.7827,1679.4478,3739.454,3355.081,6835.3975,5929.6357,10078.751,20101.334,1203.6404,297.54645,264.5352,1464.9464,404.97934,547.5114,1798.6582,496.59366,890.6,1371.3414,340.1867,291.37158,1786.8267,507.04782,605.54645,2497.3518,1663.8043,1280.0878,1605.7001,581.60657,558.0752,2216.9355,1014.8884,807.0497,3292.284,3200.8665,3134.8489,1431.4056,416.23746,389.3318,1902.0214,852.6881,867.61804,2518.4343,1390.6616,1450.1156,1693.2216,519.20026,719.6612,2494.3975,1255.3778,2298.86,3915.4758,5064.0483,4568.925,2059.8757,860.775,695.66284,3304.292,2916.218,3388.1987,5431.303,9409.167,7212.184,1743.63,705.00665,787.2344,2550.149,1893.0946,1946.6659,3587.9807,3632.0427,4045.74,2178.0862,982.165,1955.1687,3751.736,4865.378,6078.25,6047.698,12237.847,13927.893,2810.426,2187.2605,2081.572,5130.667,8917.913,13595.683,8260.2705,18761.982,31465.13) | ||
val innerIters_keypoints_inputs = Seq(32.0,96.0,160.0,224.0) | ||
val stores_keypoints_inputs = Seq(0.0,3.0,6.0,9.0) | ||
val gateds_keypoints_inputs = Seq(0.0,2.0,4.0) | ||
val outerIters_keypoints_inputs = Seq(2.0,8.0,14.0) | ||
val loads_keypoints_inputs = Seq(1.0,3.0,5.0,7.0,9.0) | ||
val outerIters_keypoints_outputs = Seq(0.0,1.2723361,1.9941331) | ||
val gateds_keypoints_outputs = Seq(1.9558235,2.0,1.8906076) | ||
val stores_keypoints_outputs = Seq(1.9404178,1.967513,2.0,2.0) | ||
val innerIters_keypoints_outputs = Seq(0.0,0.84398055,1.4838115,2.0) | ||
val loads_keypoints_outputs = Seq(0.0,0.6956528,1.1421059,1.6092755,2.0) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package models | ||
|
||
object GatedDenseStoreModelData{ | ||
val params = Seq(1719.094,1443.751,1188.5411,1953.5675,1683.4391,1444.2003,2238.9692,2014.7754,1816.3032,1861.0571,1582.0151,1334.6908,2179.2366,1936.6375,1721.997,2601.6255,2524.0051,2489.403,2032.4385,1779.5958,1559.1464,2444.615,2282.862,2164.743,3007.1777,3136.053,3387.8828,1884.6885,1617.8473,1413.9296,2207.1057,1983.6978,1847.1709,2608.8718,2502.5862,2475.0615,2055.1602,1766.9823,1582.3021,2501.278,2301.7852,2257.1072,3133.297,3341.179,3830.592,2277.6794,2047.0331,1943.1749,2868.0767,2846.0483,3154.5967,3738.446,4455.593,5777.8486,2079.1929,1846.3665,1755.442,2515.4668,2409.4521,2545.2043,3055.5276,3175.4692,3533.2097,2296.1091,2057.5986,2115.9092,2919.22,2951.5002,3575.5637,3786.52,4518.213,5991.8174,2578.2444,2433.2158,2575.8533,3387.9102,3662.4443,4704.9966,4560.117,5958.072,8487.428,1452.7153,871.69824,306.48032,1693.6179,1009.9475,408.6407,2016.089,1322.6473,657.40436,1598.5544,951.605,374.29004,1955.9524,1218.3593,427.0534,2527.0808,2148.527,1649.4894,1798.9082,1174.1864,574.18054,2297.7712,1672.3607,883.09607,3115.0178,3159.4062,3373.2473,1629.7843,985.74646,493.8577,2002.1816,1300.5197,811.5159,2515.6304,1917.2313,1320.7115,1798.4254,892.2304,244.27414,2346.29,1132.8811,45.666004,3355.0208,3463.0225,4486.307,2084.902,1282.5164,740.1523,2880.649,1998.1063,2390.8884,4406.396,6394.4463,11387.591,1853.5243,1158.9678,857.3309,2422.4895,1848.7975,2152.9004,3191.451,3034.3762,3477.758,2087.3582,993.147,1148.0575,2995.9531,2168.4907,4465.119,4532.793,6423.2305,12048.127,2496.3594,1677.7827,1679.4478,3739.454,3355.081,6835.3975,5929.6357,10078.751,20101.334,1203.6404,297.54645,264.5352,1464.9464,404.97934,547.5114,1798.6582,496.59366,890.6,1371.3414,340.1867,291.37158,1786.8267,507.04782,605.54645,2497.3518,1663.8043,1280.0878,1605.7001,581.60657,558.0752,2216.9355,1014.8884,807.0497,3292.284,3200.8665,3134.8489,1431.4056,416.23746,389.3318,1902.0214,852.6881,867.61804,2518.4343,1390.6616,1450.1156,1693.2216,519.20026,719.6612,2494.3975,1255.3778,2298.86,3915.4758,5064.0483,4568.925,2059.8757,860.775,695.66284,3304.292,2916.218,3388.1987,5431.303,9409.167,7212.184,1743.63,705.00665,787.2344,2550.149,1893.0946,1946.6659,3587.9807,3632.0427,4045.74,2178.0862,982.165,1955.1687,3751.736,4865.378,6078.25,6047.698,12237.847,13927.893,2810.426,2187.2605,2081.572,5130.667,8917.913,13595.683,8260.2705,18761.982,31465.13) | ||
val innerIters_keypoints_inputs = Seq(32.0,96.0,160.0,224.0) | ||
val stores_keypoints_inputs = Seq(0.0,3.0,6.0,9.0) | ||
val gateds_keypoints_inputs = Seq(0.0,2.0,4.0) | ||
val outerIters_keypoints_inputs = Seq(2.0,8.0,14.0) | ||
val loads_keypoints_inputs = Seq(1.0,3.0,5.0,7.0,9.0) | ||
val outerIters_keypoints_outputs = Seq(0.0,1.2723361,1.9941331) | ||
val gateds_keypoints_outputs = Seq(1.9558235,2.0,1.8906076) | ||
val stores_keypoints_outputs = Seq(1.9404178,1.967513,2.0,2.0) | ||
val innerIters_keypoints_outputs = Seq(0.0,0.84398055,1.4838115,2.0) | ||
val loads_keypoints_outputs = Seq(0.0,0.6956528,1.1421059,1.6092755,2.0) | ||
} |
Oops, something went wrong.