NASNet (LLNL#2109)
* LBANN implementation of NASNet

* Add references to NASNet and LTFB

* Code refactor for integration test, support for interactive allocation, additional code clean up

* Add integration test for NASNet
samadejacobs authored May 23, 2022
1 parent fbc9d5c commit 900f30c
Showing 10 changed files with 1,031 additions and 3 deletions.
6 changes: 6 additions & 0 deletions applications/nas/README.md
@@ -0,0 +1,6 @@
# Neural Architecture Search

This directory contains LBANN implementations of NAS search spaces and search strategies.
The first in the series is the [NASNet](https://arxiv.org/abs/1707.07012) search space, paired with a random (baseline) search strategy and with [LTFB](https://lbann.readthedocs.io/en/latest/execution_algorithms/ltfb.html).
Reference implementations of additional search spaces and search strategies will be added over time.
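As a usage sketch (assuming an LC-style system supported by `lbann.launcher`), the NASNet micro search described below can be launched from the `nasnet/` directory:

```shell
# Random (baseline) search: 4 nodes x 2 procs/node, 2 procs per trainer,
# i.e. a population of 4 candidate networks trained concurrently.
python3 cifar_search.py --nodes 4 --ppn 2 --ppt 2

# Add --use-ltfb to search with LTFB instead of random sampling.
python3 cifar_search.py --nodes 4 --ppn 2 --ppt 2 --use-ltfb
```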

147 changes: 147 additions & 0 deletions applications/nas/nasnet/cifar_networks.py
@@ -0,0 +1,147 @@
# NASNet Search Space https://arxiv.org/pdf/1707.07012.pdf
# code modified from DARTS https://github.com/quark0/darts
import numpy as np
import sys
import os
import time
from collections import namedtuple
import lbann
import lbann.models
import lbann.models.resnet
from search import micro_encoding
from os.path import join
import data.cifar10

sys.path.insert(0, os.getenv('PWD'))
import search.model as cifar


Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
Genotype_norm = namedtuple('Genotype', 'normal normal_concat')
Genotype_redu = namedtuple('Genotype', 'reduce reduce_concat')

# The candidate operations in the search space; these must match the
# operation definitions in micro_operations.
PRIMITIVES = [
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
    'sep_conv_7x7',
    'conv_7x1_1x7',
]

def generate_genomes(pop_size,
                     num_blocks=5,
                     num_ops=7,
                     num_cells=2):
    seed = 0
    np.random.seed(seed)
    B, n_ops, n_cell = num_blocks, num_ops, num_cells
    networks = []
    genotypes = []
    network_id = 0

    while len(networks) < pop_size:
        # Sample a random architecture: each block gets two
        # (operation, input state) pairs. Block b may read from the two
        # cell-level inputs plus the outputs of blocks 0..b-1.
        bit_string = []
        for c in range(n_cell):
            for b in range(B):
                bit_string += [np.random.randint(n_ops),
                               np.random.randint(b + 2),
                               np.random.randint(n_ops),
                               np.random.randint(b + 2)]

        genome = micro_encoding.convert(bit_string)
        # Check against already-sampled networks to avoid duplicates
        doTrain = True
        for network in networks:
            if micro_encoding.compare(genome, network):
                doTrain = False
                break

        if doTrain:
            genotype = micro_encoding.decode(genome)
            #print("Network id, bitstring, genome, genotype ", network_id, bit_string, genome, genotype)
            networks.append(genome)
            genotypes.append(genotype)
            network_id += 1

    return genotypes


def create_networks(exp_dir,
                    num_epochs,
                    mini_batch_size,
                    pop_size,
                    use_ltfb=False,
                    num_blocks=5,
                    num_ops=7,
                    num_cells=2):
    trainer_id = 0

    # Setup shared data reader and optimizer
    reader = data.cifar10.make_data_reader(num_classes=10)
    opt = lbann.Adam(learn_rate=0.0002, beta1=0.9, beta2=0.99, eps=1e-8)

    genotypes = generate_genomes(pop_size, num_blocks, num_ops, num_cells)
    for g in genotypes:
        mymodel = cifar.NetworkCIFAR(16, 10, 8, False, g)

        images = lbann.Input(data_field='samples')
        labels = lbann.Input(data_field='labels')

        preds, _ = mymodel(images)
        probs = lbann.Softmax(preds)
        cross_entropy = lbann.CrossEntropy(probs, labels)
        top1 = lbann.CategoricalAccuracy(probs, labels)

        obj = lbann.ObjectiveFunction([cross_entropy])
        metrics = lbann.Metric(top1, name='accuracy', unit='%')
        callbacks = [lbann.CallbackPrint(),
                     lbann.CallbackTimer()]

        model = lbann.Model(epochs=num_epochs,
                            layers=[images, labels],
                            objective_function=obj,
                            metrics=metrics,
                            callbacks=callbacks)

        # Setup trainer
        trainer = lbann.Trainer(mini_batch_size=mini_batch_size)

        if use_ltfb:
            print("Using LTFB")
            SGD = lbann.BatchedIterativeOptimizer
            RPE = lbann.RandomPairwiseExchange
            ES = lbann.RandomPairwiseExchange.ExchangeStrategy(strategy='checkpoint_binary')
            metalearning = RPE(
                metric_strategies={'accuracy': RPE.MetricStrategy.HIGHER_IS_BETTER},
                exchange_strategy=ES)
            ltfb = lbann.LTFB("ltfb",
                              metalearning=metalearning,
                              # local SGD steps between tournament exchanges
                              local_algo=SGD("local sgd", num_iterations=625),
                              metalearning_steps=num_epochs)
            trainer = lbann.Trainer(mini_batch_size=mini_batch_size,
                                    training_algo=ltfb)

        # Export one Protobuf experiment file per trainer/candidate network
        lbann.proto.save_prototext(
            os.path.join(exp_dir, f'experiment.prototext.trainer{trainer_id}'),
            model=model,
            optimizer=opt,
            data_reader=reader,
            trainer=trainer)

        trainer_id += 1

    return trainer, model, reader, opt

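The genome layout sampled in `generate_genomes` is easiest to see in isolation. Below is a minimal, self-contained sketch of the bit-string shape; it deliberately avoids `search.micro_encoding`, so all names in it are illustrative only:

```python
import numpy as np

# Same sampling scheme as generate_genomes: 5 blocks per cell,
# 7 candidate operations, 2 cells (normal + reduction).
num_blocks, num_ops, num_cells = 5, 7, 2
rng = np.random.default_rng(0)

bit_string = []
for _ in range(num_cells):
    for b in range(num_blocks):
        # Four genes per block: (op1, input1, op2, input2). Block b may
        # read from the two cell inputs plus the b earlier block outputs,
        # hence the randint over b + 2 choices.
        bit_string += [int(rng.integers(num_ops)), int(rng.integers(b + 2)),
                       int(rng.integers(num_ops)), int(rng.integers(b + 2))]

print(len(bit_string))  # 2 cells * 5 blocks * 4 genes = 40
```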

99 changes: 99 additions & 0 deletions applications/nas/nasnet/cifar_search.py
@@ -0,0 +1,99 @@
import numpy as np
import sys
import os
import time
import lbann
import argparse
import lbann.contrib.args
import lbann.contrib.launcher
from os.path import join
import subprocess

sys.path.insert(0, os.getenv('PWD'))
import cifar_networks


# ----------------------------------
# Command-line arguments
# ----------------------------------

desc = 'Micro search on CIFAR-10 data using LBANN.'
parser = argparse.ArgumentParser(description=desc)
#lbann.contrib.args.add_scheduler_arguments(parser)

# NAS parameters
parser.add_argument(
    '--num-blocks', action='store', default=5, type=int,
    help='Number of blocks per cell (default: 5)')
parser.add_argument(
    '--n-ops', action='store', default=7, type=int,
    help='Number of operations (default: 7)')
parser.add_argument(
    '--n-cell', action='store', default=2, type=int,
    help='Number of cells (default: 2)')

parser.add_argument(
    '--use-ltfb', action='store_true', help='Use LTFB')

# Training (hyper)parameters
parser.add_argument(
    '--mini-batch-size', action='store', default=64, type=int,
    help='mini-batch size (default: 64)', metavar='NUM')
parser.add_argument(
    '--num-epochs', action='store', default=20, type=int,
    help='number of epochs (default: 20)', metavar='NUM')

# Compute (job) parameters
parser.add_argument(
    '--nodes', action='store', default=4, type=int,
    help='Number of compute nodes (default: 4)')
parser.add_argument(
    '--ppn', action='store', default=2, type=int,
    help='Processes per node (default: 2)')
parser.add_argument(
    '--ppt', action='store', default=2, type=int,
    help='Processes per trainer (default: 2)')
parser.add_argument(
    '--job-name', action='store', default='denas_cifar10', type=str,
    help='scheduler job name (default: denas_cifar10)')

parser.add_argument(
    '--exp-dir', action='store', default='exp_cifar10', type=str,
    help='experiment directory (default: exp_cifar10)')
lbann.contrib.args.add_optimizer_arguments(parser, default_learning_rate=0.1)
args = parser.parse_args()



if __name__ == "__main__":
    tag = 'ltfb' if args.use_ltfb else 'random'
    expd = 'search-{}-{}-{}'.format('nasnet-micro-cifar10', tag,
                                    time.strftime("%Y%m%d-%H%M%S"))
    if not os.path.exists(expd):
        os.mkdir(expd)
    print('Experiment dir : {}'.format(expd))

    script = lbann.launcher.make_batch_script(nodes=args.nodes,
                                              procs_per_node=args.ppn,
                                              experiment_dir=expd)
    # One trainer per candidate network: total ranks / ranks per trainer
    pop_size = int(args.nodes * args.ppn / args.ppt)

    cifar_networks.create_networks(expd,
                                   args.num_epochs,
                                   args.mini_batch_size,
                                   pop_size,
                                   use_ltfb=args.use_ltfb,
                                   num_blocks=args.num_blocks,
                                   num_ops=args.n_ops,
                                   num_cells=args.n_cell)

    proto_file = os.path.join(script.work_dir, 'experiment.prototext.trainer0')
    command = [
        lbann.lbann_exe(),
        f'--procs_per_trainer={args.ppt}',
        '--generate_multi_proto',
        f'--prototext={proto_file}']
    script.add_parallel_command(command)

    # Run script
    script.run(True)
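For a quick sanity check of the job shape, the trainer count implied by the default flags works out as follows (a worked example, not part of the script):

```python
# 4 nodes x 2 procs/node = 8 MPI ranks; 2 ranks per trainer -> 4 trainers.
nodes, ppn, ppt = 4, 2, 2
pop_size = (nodes * ppn) // ppt
print(pop_size)  # 4: create_networks writes experiment.prototext.trainer0..3
```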



28 changes: 28 additions & 0 deletions applications/nas/nasnet/data/cifar10/__init__.py
@@ -0,0 +1,28 @@
import os
import os.path

import google.protobuf.text_format
import lbann
import lbann.contrib.lc.paths

def make_data_reader(num_classes=10):

    # Load Protobuf message from file
    current_dir = os.path.dirname(os.path.realpath(__file__))
    protobuf_file = os.path.join(current_dir, 'data_reader.prototext')
    message = lbann.lbann_pb2.LbannPB()
    with open(protobuf_file, 'r') as f:
        google.protobuf.text_format.Merge(f.read(), message)
    message = message.data_reader

    # Check that the data path is accessible
    data_dir = lbann.contrib.lc.paths.cifar10_dir()
    if not os.path.isdir(data_dir):
        raise FileNotFoundError('could not access {}'.format(data_dir))

    # Point both the train and test readers at the data
    message.reader[0].data_filedir = data_dir
    message.reader[1].data_filedir = data_dir

    return message
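A minimal usage sketch (this assumes an LLNL LC machine, where `lbann.contrib.lc.paths.cifar10_dir()` resolves to a readable CIFAR-10 directory):

```python
import data.cifar10

message = data.cifar10.make_data_reader(num_classes=10)
# The returned message carries the two readers from data_reader.prototext
print(message.reader[0].role, message.reader[1].role)  # train test
```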
44 changes: 44 additions & 0 deletions applications/nas/nasnet/data/cifar10/data_reader.prototext
@@ -0,0 +1,44 @@
data_reader {
  reader {
    name: "cifar10"
    role: "train"
    shuffle: true
    data_filedir: "path/to/cifar10/data"
    validation_percent: 0.1
    tournament_percent: 0.1
    absolute_sample_count: 0
    percent_of_data_to_use: 1.0

    transforms {
      horizontal_flip {
        p: 0.5
      }
    }
    transforms {
      normalize_to_lbann_layout {
        means: "0.44653 0.48216 0.4914"
        stddevs: "0.26159 0.24349 0.24703"
      }
    }
  }
  reader {
    name: "cifar10"
    role: "test"
    shuffle: true
    data_filedir: "path/to/cifar10/data"
    absolute_sample_count: 0
    percent_of_data_to_use: 1.0

    transforms {
      horizontal_flip {
        p: 0.5
      }
    }
    transforms {
      normalize_to_lbann_layout {
        means: "0.44653 0.48216 0.4914"
        stddevs: "0.26159 0.24349 0.24703"
      }
    }
  }
}