Skip to content

Commit

Permalink
Clean up use of area/probability variants
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Nyhus committed Feb 16, 2024
1 parent 9400a01 commit 16680bf
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 23 deletions.
2 changes: 1 addition & 1 deletion confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def main():
args.feature,
progress,
model.get(args.model,
feature.model_output_type(args.feature),
feature.result_type(args.feature),
args.load_model),
model_version,
args.batch_size,
Expand Down
57 changes: 42 additions & 15 deletions database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging
import sqlite3

import feature


log = logging.getLogger('database')

Expand Down Expand Up @@ -416,25 +418,49 @@ def mark_checked(cursor, z, x, y):


def training_tiles(cursor, feature_name):
cursor.execute('''select tile_hash, score, 0
from training_set
natural join tile_positions
natural join true_score
where feature_name = ?
''',
[feature_name])
if feature.result_type(feature_name) == 'probability':
cursor.execute('''select tile_hash, has_feature, 0
from training_set
natural join tile_positions
natural join has_feature
where feature_name = ?
''',
[feature_name])
elif feature.result_type(feature_name) == 'area':
cursor.execute('''select tile_hash, score, 0
from training_set
natural join tile_positions
natural join true_score
where feature_name = ?
''',
[feature_name])
else:
raise RuntimeError

return cursor.fetchall()


def validation_tiles(cursor, feature_name):
cursor.execute('''select tile_hash, true_score.score, scores.score
from validation_set
natural join tile_positions
natural left join true_score
natural left join scores
where feature_name = ?
''',
[feature_name])
if feature.result_type(feature_name) == 'probability':
cursor.execute('''select tile_hash, has_feature, score
from validation_set
natural join tile_positions
natural left join has_feature
natural left join scores
where feature_name = ?
''',
[feature_name])
elif feature.result_type(feature_name) == 'area':
cursor.execute('''select tile_hash, true_score.score, scores.score
from validation_set
natural join tile_positions
natural left join true_score
natural left join scores
where feature_name = ?
''',
[feature_name])
else:
raise RuntimeError
return cursor.fetchall()


Expand All @@ -451,6 +477,7 @@ def validation_tiles_for_scoring(cursor, current_model, feature_name):
[current_model, feature_name])
return cursor.fetchall()


def set_true_score(cursor, tile_hash, feature_name, true_score):
assert type(tile_hash) == str
assert type(feature_name) == str
Expand Down
2 changes: 1 addition & 1 deletion feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def overpass_query(feature):
'''


def model_output_type(feature):
def result_type(feature):
return {
'large_solar': 'probability',
'playground': 'probability',
Expand Down
6 changes: 3 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def inception_resnet_v2(input_layer, input_shape):
)


def get(model_type, output_type, weights_from=None, learning_rate=1e-4):
def get(model_type, result_type, weights_from=None, learning_rate=1e-4):
input_shape = (256, 256, 3)
inputs = keras.layers.Input(input_shape)

Expand All @@ -132,9 +132,9 @@ def get(model_type, output_type, weights_from=None, learning_rate=1e-4):
dense_1 = keras.layers.Dense(dense_width, activation='relu')(flatten)
dense_2 = keras.layers.Dense(dense_width, activation='relu')(dense_1)

if output_type == 'probability':
if result_type == 'probability':
dense_3 = keras.layers.Dense(1, activation='sigmoid')(dense_2)
elif output_type == 'area':
elif result_type == 'area':
dense_3 = keras.layers.Dense(1, activation='relu')(dense_2)

model = keras.models.Model(inputs=inputs, outputs=dense_3)
Expand Down
2 changes: 1 addition & 1 deletion score_tiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def main():
image_dir = pathlib.Path(args.tile_path)

m = model.get(args.model,
feature.model_output_type(args.feature),
feature.result_type(args.feature),
args.load_model)

progress = Progress()
Expand Down
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tensorflow as tf

import database
import feature
import model
import util

Expand Down Expand Up @@ -183,7 +184,7 @@ def main():
validation_data = validation_data.batch(args.batch_size)

m = model.get(args.model,
feature.model_output_type(args.feature_name),
feature.result_type(args.feature),
args.load_model,
args.learning_rate)

Expand Down
1 change: 0 additions & 1 deletion web.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sqlite3

import database
import feature
import util


Expand Down

0 comments on commit 16680bf

Please sign in to comment.