Continual Learning - Toy Problems

In this folder we investigate our ideas on toy problems to gain a better understanding.

Note, we recommend running the scripts in an IPython console with inline plotting enabled, since many plots are generated to visualize training progress.

For instance, in an ipython qtconsole the default training can be started by executing the following (CPU-only computation is usually faster for this toy example):

>>> %matplotlib inline
>>> %run train.py --no_cuda

Note, in general one needs to distinguish between tasks with overlapping input domains (e.g., different 1D functions with the same input domain x) and tasks with (practically) disjoint input domains (e.g., the inputs x from different tasks are drawn from narrow Gaussians with far apart means). The former set of tasks can only be solved by a system that can be conditioned on the current task (as long as the system only has 1 output head for all tasks), because the task to be solved cannot be inferred from the input. The latter type of tasks can also be solved by a system that infers the task identity from the input.
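The argument for overlapping input domains can be made concrete with a minimal sketch. The two task functions, the input grid and the oracle `conditioned_model` below are hypothetical and not part of the scripts in this folder; they only illustrate why a single-output system that sees the input alone cannot fit two different functions on the same domain.

```python
import numpy as np

# Two hypothetical tasks that share the same input domain [-1, 1].
task_funcs = [lambda x: x, lambda x: -x]
x = np.linspace(-1., 1., 5)
targets = np.stack([f(x) for f in task_funcs])  # shape (num_tasks, num_inputs)

# Without conditioning, a single-output model has to emit one value per input.
# The MSE-optimal choice is the mean target over (equally weighted) tasks.
unconditioned_pred = targets.mean(axis=0)
unconditioned_mse = np.mean((targets - unconditioned_pred) ** 2)

# With conditioning, the model additionally receives the task identity.
# This oracle stands in for a perfectly trained conditioned system.
def conditioned_model(x, task_id):
    return task_funcs[task_id](x)

conditioned_mse = np.mean([(conditioned_model(x, t) - targets[t]) ** 2
                           for t in range(len(task_funcs))])

print(unconditioned_mse, conditioned_mse)
```

Even the best unconditioned predictor keeps a nonzero error here, whereas the conditioned one can reach zero error.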

General notes

  • Usually, one can explore the options of a script by typing python3 SCRIPT_NAME --help.
  • The script train can be used to train tasks in a continual learning setting. The script provides several different ways of training.
  • The dataset to be used (for a description of each dataset see below) can be selected via the command line option --dataset.
  • The set of tasks associated with a dataset cannot be configured via the command line. To provide the largest degree of flexibility, we ask you to refer to the task definition in the function _generate_tasks of the module train_utils.
  • A typical baseline for continual learning is multitask learning, where all the data from all tasks is available at all times during training. We provide several options to train such a baseline in the script train_multitask.
  • If automatic task recognition should be used during inference (rather than manually selecting the correct task embedding for the hypernetwork), please refer to the option --use_task_detection.
  • The file train_ewc can be used to train a main network in a continual learning setup via EWC (or online EWC). Note, when using tasks with overlapping input domains, you can use a multi-head setting (option --multi_head). You can also train a hypernetwork with the EWC regularizer. Please refer to the option --reg of the train script.

1D-function regression

In this problem set (defined in the module regression1d_data) each task describes a 1D function. Hence, the goal is to learn a predefined function y = f(x) on a predefined input domain.

Examples

3 polynomials

In this example, we defined the following set of tasks in the method _generate_1d_tasks of module train_utils:

map_funcs = [lambda x : (x+3.), 
             lambda x : 2. * np.power(x, 2) - 1,
             lambda x : np.power(x-3., 3)]
num_tasks = len(map_funcs)
x_domains = [[-4,-2], [-1,1], [2,4]]
std = .05

Note, we have chosen a set of tasks with distinct input domains, as this is the type of benchmark typically explored in the CL literature. Note that our method is agnostic to this setting, as we can always condition on the task directly or automatically extract the task from context information (which can be more informative than pure network inputs).
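To get a feeling for the data these three tasks produce, the following minimal sketch draws noisy samples from each polynomial on its input domain. It assumes that inputs are sampled uniformly from the domain and that std is the standard deviation of additive Gaussian noise on the outputs; the helper sample_task and the sample count are hypothetical and may differ from what the module regression1d_data does.

```python
import numpy as np

# Task definition as given above.
map_funcs = [lambda x: (x + 3.),
             lambda x: 2. * np.power(x, 2) - 1,
             lambda x: np.power(x - 3., 3)]
x_domains = [[-4, -2], [-1, 1], [2, 4]]
std = .05

def sample_task(func, domain, num_samples, std, rng):
    """Hypothetical helper: uniform inputs, Gaussian output noise (assumed)."""
    x = rng.uniform(domain[0], domain[1], size=(num_samples, 1))
    y = func(x) + rng.normal(0., std, size=x.shape)
    return x, y

rng = np.random.default_rng(42)
datasets = [sample_task(f, d, 100, std, rng)
            for f, d in zip(map_funcs, x_domains)]

for task_id, (x, y) in enumerate(datasets):
    print(task_id, x.shape, y.shape, float(x.min()), float(x.max()))
```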

Example results with these tasks in a continual learning setting with our approach can be obtained via:

$ python3 train.py --no_cuda --beta=0.005 --emb_size=2 --n_iter=4001 --lr_hyper=1e-2 --data_random_seed=42

The final MSE values should correspond to 0.00419, 0.00239, 0.00607. Note, this setting would correspond to CL scenario CL1 (see paper).

Fine-tuning

A fine-tuning plot can be generated by simply disabling the regularizer (--beta=0) in the above call.

Training from scratch

Training from scratch refers to reinitializing the main network / hypernetwork after each task (hence, it differs from fine-tuning in that no transfer can occur).

$ python3 train.py --no_cuda --beta=0 --emb_size=2 --n_iter=4001 --lr_hyper=1e-2 --data_random_seed=42 --train_from_scratch

Automatic task recognition

Automatic task recognition refers to the addition of a third system that automatically detects which task embedding to choose from the context (which is always the current input in our case).

$ python3 train.py --no_cuda --beta=0.005 --emb_size=2 --n_iter=4001 --lr_hyper=1e-2 --data_random_seed=42 --use_task_detection --ae_beta_ce=20.

Note, main-/hypernet are trained independently of the task-recognition model, which is essentially a VAE. Only a simple task-recognition model is implemented (e.g., no growing heads) as task recognition from inputs alone is not the goal of these experiments. For a complete implementation of a task-recognition model as described in the paper, please refer to our classification experiments.

Multi-task learning

Multi-task learning refers to learning all tasks at once. We provide several options to perform multi-task learning, see the option --method of the script train_multitask. Method 0 refers to learning a main network only (note, the option --emb_size does not have to be specified for --method=0).

$ python3 train_multitask.py --no_cuda --emb_size=2 --n_iter=4001 --lr_hyper=5e-2 --data_random_seed=42 --method=0

2D Mixture of Gaussians

This is another regression problem (defined in the module gaussian_mixture_data), where the input x of each task is drawn from a 2D Gaussian bump, such that the inputs of all tasks combined are drawn from a mixture of Gaussians. Therefore, this problem set makes it easy to create tasks where the task identity can be inferred from the network input alone, provided the variance of each bump is small. Note, even though the domain of each task is the whole real plane, the train and test data (as sampled from a Gaussian) cluster around the mean and therefore define regions of importance in which the performance of the regression is measured.
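The claim that the task identity can be read off the input when the bumps are narrow can be checked with a minimal sketch. The means, the variance and both helper functions below are hypothetical and do not reflect the interface of the module gaussian_mixture_data; the nearest-mean rule is only a stand-in for a task-recognition model.

```python
import numpy as np

def sample_gmm_inputs(means, variance, num_samples, rng):
    """Draw num_samples 2D inputs per task from isotropic Gaussian bumps."""
    inputs, task_ids = [], []
    for task_id, mean in enumerate(means):
        cov = variance * np.eye(2)
        inputs.append(rng.multivariate_normal(mean, cov, size=num_samples))
        task_ids.append(np.full(num_samples, task_id))
    return np.concatenate(inputs), np.concatenate(task_ids)

def nearest_mean_task(x, means):
    """Infer the task as the index of the closest bump mean."""
    diffs = x[:, None, :] - np.asarray(means)[None, :, :]
    return np.argmin(np.linalg.norm(diffs, axis=2), axis=1)

rng = np.random.default_rng(0)
means = [[-5., -5.], [0., 0.], [5., 5.]]  # far apart relative to the variance
x, task_ids = sample_gmm_inputs(means, variance=0.1, num_samples=200, rng=rng)

accuracy = np.mean(nearest_mean_task(x, means) == task_ids)
print(x.shape, accuracy)
```

Increasing the variance or moving the means closer together makes the bumps overlap, at which point the input alone no longer determines the task.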

We haven't explored this dataset yet. Why is it interesting? Imagine all regression tasks are identical up to a shift in the input domain. Then, the embedding space would only need to encode the input domain shift, whereas the hypernet encodes the function. Such things might be easier to visualize in a 2D rather than a 1D space.
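A minimal sketch of such a family of shifted tasks is given below. The shared function base_func, the shifts and the helper make_shifted_task are hypothetical and not taken from the code in this folder; the sketch only shows that every task is fully described by its shift once the shared function is known.

```python
import numpy as np

def base_func(x):
    """Hypothetical function shared by all tasks (inputs of shape [N, 2])."""
    return np.sum(np.square(x), axis=1, keepdims=True)

def make_shifted_task(shift):
    """Return the task function f(x) = base_func(x - shift)."""
    shift = np.asarray(shift, dtype=float)
    return lambda x: base_func(x - shift)

shifts = [[-5., -5.], [0., 0.], [5., 5.]]
tasks = [make_shifted_task(s) for s in shifts]

# The same offset from each task's shift yields the same target in every task.
probe = np.array([[0.5, -0.25]])
values = [task(probe + np.asarray(s)) for task, s in zip(tasks, shifts)]
print([float(v[0, 0]) for v in values])
```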