Skip to content

Commit

Permalink
CIFAR-10 loader
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Daniel Pantasdo committed Jun 14, 2018
1 parent f38a9f6 commit d0c18dc
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 23 deletions.
56 changes: 36 additions & 20 deletions cifar10/dataset/cifar_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,43 @@
def unpickle(file):
    """Load and return the object stored in the pickle file at ``file``.

    The file is opened in binary mode and decoded with ``encoding='bytes'``,
    so 8-bit strings pickled by Python 2 come back as ``bytes`` objects
    (callers therefore look keys up with e.g. ``'labels'.encode()``).

    Args:
        file: Path of the pickle file to read.

    Returns:
        Whatever object the pickle file contains.
    """
    # NOTE(review): pickle.load can execute arbitrary code while loading;
    # only call this on batch files obtained from a trusted source.
    with open(file, 'rb') as fo:
        # Return directly instead of binding the result to a local named
        # ``dict``, which shadowed the builtin.
        return pickle.load(fo, encoding='bytes')

def get_data_and_labels():
data = []
labels = []
def get_data_and_labels(filename):
batch = unpickle(filename)

for i in range(5):
id = str(i + 1)
filename = "cifar10/dataset/data_batch_" + id
batch = unpickle(filename)
labels_key = 'labels'.encode()
labels = batch[labels_key]

labels_key = 'labels'.encode()
batch_labels = batch[labels_key]
labels.extend(batch_labels)
data_key = 'data'.encode()
raw_batch_data = batch[data_key]

data_key = 'data'.encode()
raw_batch_data = batch[data_key]
data = []
for i in range(len(raw_batch_data)):
np_data = (np.array(raw_batch_data[i], dtype='float32') / 255)
data.append(np_data)

batch_data = []
for i in range(len(raw_batch_data)):
np_data = (np.array(raw_batch_data[i], dtype='float32') / 255)
batch_data.append(np_data)
return (data, labels)

data.extend(batch_data)
def get_train_data_labels():
train_data = []
train_labels = []

return (data, labels)
for i in range(5):
id = str(i + 1)
train_filename = "cifar10/dataset/data_batch_" + id
(batch_data, batch_labels) = get_data_and_labels( train_filename )

train_data.extend(batch_data)
train_labels.extend(batch_labels)

return (train_data, train_labels)

def get_segmented_data():
(data, labels) = get_data_and_labels()
def get_test_data_labels():
    """Return ``(data, labels)`` loaded from the CIFAR-10 test batch.

    The training batches are read from ``cifar10/dataset/``, so the test
    batch is looked up there first. If that file is absent, the original
    bare ``"test_batch"`` path (relative to the working directory) is used,
    which keeps existing setups working.

    Returns:
        The ``(data, labels)`` tuple produced by ``get_data_and_labels``.
    """
    import os  # local import: keeps this block self-contained

    # NOTE(review): assumes test_batch is stored next to data_batch_N in
    # cifar10/dataset/ -- confirm against the repository layout.
    dataset_path = "cifar10/dataset/test_batch"
    legacy_path = "test_batch"

    filename = dataset_path if os.path.exists(dataset_path) else legacy_path
    return get_data_and_labels(filename)

def get_segmented_data(data, labels):
num_data = len(labels)
num_labels = 10

Expand All @@ -49,6 +56,15 @@ def get_segmented_data():

return segmented_data

def get_cifar_10_train_test():
    """Load CIFAR-10 and return ``(train, test)`` in segmented form.

    The raw training and test ``(data, labels)`` pairs are fetched first,
    then each pair is passed through ``get_segmented_data``.

    Returns:
        A 2-tuple ``(segmented_training_data, segmented_test_data)``.
    """
    # Fetch both raw splits before segmenting either one.
    train_data, train_labels = get_train_data_labels()
    test_data, test_labels = get_test_data_labels()

    train_segments = get_segmented_data(train_data, train_labels)
    test_segments = get_segmented_data(test_data, test_labels)

    return train_segments, test_segments

def visualize_image(segmented_data, img_key, idx):
flattened_img = segmented_data[img_key][idx]

Expand Down
13 changes: 11 additions & 2 deletions data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from timeit import default_timer as timer

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
mnist_16_filename = "mnist_16"
from cifar10.dataset import cifar_extractor

mnist_16 = "mnist_16"
cifar_10 = "cifar_10"

def segment_data(data):
segmented_data = []
Expand All @@ -32,7 +35,7 @@ def segment_data(data):
return segmented_tensor

def load_data(filename):
if filename == mnist_16_filename:
if filename == mnist_16:
print("Loading data...")
train_raw = genfromtxt('train_mnist_16.csv', delimiter=',')
test_raw = genfromtxt('test_mnist_16.csv', delimiter=',')
Expand All @@ -42,6 +45,12 @@ def load_data(filename):
test_data = segment_data(test_raw)
print("Data segmented!")

return (train_data, test_data)
if filename == cifar_10:
print("Loading data...")
(train_data, test_data) = cifar_extractor.get_cifar_10_train_test()
print("Data loaded!")

return (train_data, test_data)

return
17 changes: 16 additions & 1 deletion struct_to_spn.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,25 @@ def generate_network(self):
print("Done")
return

# Module-level accumulators filled in by get_edges.
total_edge_count = 0
edges = []


def get_edges(level, level_type, level_nodes, edge_count):
    """Collect the edges of one level when that level is a 'Sum' level.

    Used as the callback handed to ``traverse_by_level``. Levels whose
    ``level_type`` is not ``'Sum'`` are ignored.

    Args:
        level: Unused by this callback.
        level_type: Type tag of the level; only ``'Sum'`` is processed.
        level_nodes: Iterable of nodes, each exposing an ``edges`` iterable.
        edge_count: Amount added to the module-level ``total_edge_count``.

    Returns:
        None. Results are accumulated in the module-level ``edges`` list
        and ``total_edge_count`` counter.
    """
    global edges, total_edge_count

    if level_type == 'Sum':
        for sum_node in level_nodes:
            edges.extend(sum_node.edges)

        total_edge_count += edge_count

def main():
structure = MultiChannelConvSPN(8, 1, 1, 2, 2, 1)
structure = MultiChannelConvSPN(4, 1, 1, 2, 2, 1)
shared_parameters = param.Param()

structure.traverse_by_level(get_edges)

network = MatrixSPN(structure, shared_parameters, is_cuda=False)
pass

Expand Down

0 comments on commit d0c18dc

Please sign in to comment.