
Commit

Clean
Muennighoff committed Aug 2, 2023
1 parent 5de164d commit 4c01aec
Showing 22 changed files with 12,194 additions and 54 deletions.
41 changes: 9 additions & 32 deletions README.md
@@ -1,6 +1,6 @@
# OctoPack: Instruction Tuning Code Large Language Models

![](visuals/banner.png)

This repository provides an overview of all components from the paper [OctoPack: Instruction Tuning Code Large Language Models](TODO).

@@ -17,9 +17,8 @@ This repository provides an overview of all components from the paper [OctoPack:
- [Creation](#creation)
- [Training](#training)
- [Transformers](#transformers)
- [OctoGeeX](#octogeex)
- [Megatron-LM](#megatron-lm)
- [Other](#other)
- [Citation](#citation)

<!-- /TOC -->
@@ -166,6 +165,7 @@ accelerate launch main.py \
--max_length_generation 2048 \
--precision bf16
```
- Unfortunately, there is some randomness in the results depending on the Python version you use for evaluation and on the `batch_size`. We use `batch_size=5` and Python 3.9.13.
- We provide the exact scripts we used in `evaluation/run/eval_scripts` for each model. There is also a `_range.sh` script for each task (e.g. `evaluation/run/eval_scripts/eval_humanevalfix_range.sh`) that runs each sample individually, which is much faster if you have multiple GPUs available. In the `_range.sh` scripts you need to specify the model and language you would like to run. After running one, you will have 164 generation files, which you merge with `python evaluation/run/merge_generations.py "generations_*json"`. Subsequently, run the evaluation as explained in the next step (a sketch of this workflow follows below).
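A minimal sketch of the range-then-merge workflow, assuming the model and language have already been set inside the script as described above:

```bash
# Generate one file per HumanEvalPack sample (164 files in total),
# then merge them into a single generations file for evaluation.
bash evaluation/run/eval_scripts/eval_humanevalfix_range.sh
python evaluation/run/merge_generations.py "generations_*json"
```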

3. **Evaluate:** If you have only created generations without evaluating them (e.g. by adding the `--generation_only` flag or using `_range.sh` scripts), you can use the notebook at `evaluation/run/humanevalpack_evaluation` or [this colab](https://colab.research.google.com/drive/1tlpGcDPdKKMDqDS0Ihwh2vR_MGlzAPC_?usp=sharing) to evaluate the generations. It contains a section for each programming language that first installs the language and then, given the path to your generations, evaluates them and reports the pass@k scores.
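For reference, pass@k is typically computed with the unbiased estimator from the Codex paper (Chen et al., 2021). A minimal sketch, assuming `n` generations per problem of which `c` pass the tests:

```python
import math

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k: probability that at least one of k generations
    sampled from n (of which c are correct) passes the tests."""
    if n - c < k:
        return 1.0
    # 1 - C(n-c, k) / C(n, k), computed stably as a running product
    return 1.0 - math.prod(1.0 - k / i for i in range(n - c + 1, n + 1))
```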
@@ -182,7 +182,11 @@ To create HumanEvalPack, we follow these steps:

### Transformers

The finetuning script to create OctoCoder is at `finetuning/finetune.py`. The folder contains a `README.md` with instructions.
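A hypothetical launch, for orientation only; the argument names here are illustrative assumptions, not the script's actual interface, which is documented in the folder's `README.md`:

```bash
# Illustrative only: flag names are assumptions; see finetuning/README.md
# for the real arguments.
python finetuning/finetune.py \
    --model_path bigcode/starcoderbase \
    --dataset_name bigcode/commitpackft
```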

### OctoGeeX



### Megatron-LM

@@ -196,35 +200,8 @@ We did not end up using Megatron-LM fine-tuning for the model in the paper, but
6. Create two files `train_data_paths.txt.tmp` and `valid_data_paths.txt.tmp` that contain the paths to the above created tokenized dataset. For example, they could look like `"train: 1.0 0:0.95 output_prefix"` and `"valid: 1.0 0.95:1.0 output_prefix"`. In this case the dataset is split into 95% training and 5% validation. The first number is the weight of the dataset, and the second and third numbers are the start and end of the dataset as fractions (see the shell snippet after this list).
7. Rename the checkpoint downloaded to `release` i.e. `mv starcoderbase-megatron/iter* starcoderbase-megatron/release` and create a file `starcoderbase-megatron/latest_checkpointed_iteration.txt` that contains simply `release` (`echo release > starcoderbase-megatron/latest_checkpointed_iteration.txt`).
8. Modify `training/finetune_starcoderbase.sh`: adapt `CHECKPOINT_PATH` to point to the downloaded Megatron-LM checkpoint, `WEIGHTS_TRAIN` & `WEIGHTS_VALID` to the txt files created above, and `TOKENIZER_FILE` to StarCoder's `tokenizer.json`; point to your environment and cache locations; and modify the SBATCH settings to suit your setup. Then run it with `bash training/finetune_starcoderbase.sh`. You can interrupt and resume training. However, when resuming, remove `--no_load_optim` and `--no_load_rng` from the command-line arguments in the script so that the optimizer and random-number-generator state are loaded from the newly saved checkpoint (we only want to avoid loading them from starcoderbase).
9. Convert the saved checkpoint using the script at `convert_large.sh`. It contains instructions on which repos to download: update the paths in `convert_large.sh`, download the marked repos, and run it.
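For concreteness, a shell sketch of steps 6 and 7, using the exact strings from above (`output_prefix` stands for whatever prefix your tokenized dataset was written to):

```bash
# Step 6: 95%/5% train/validation split files
echo "train: 1.0 0:0.95 output_prefix" > train_data_paths.txt.tmp
echo "valid: 1.0 0.95:1.0 output_prefix" > valid_data_paths.txt.tmp

# Step 7: expose the downloaded checkpoint under the name Megatron-LM expects
mv starcoderbase-megatron/iter* starcoderbase-megatron/release
echo release > starcoderbase-megatron/latest_checkpointed_iteration.txt
```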

#### Other
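To sanity-check a converted checkpoint, load it with `transformers` and generate from it: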

```python
# pip install -q transformers
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Path to the converted checkpoint (an alternative sharded conversion
# lives at /gpfsscratch/rech/ajs/commun/Bigcode-large-megatron_conv/base/shard)
checkpoint = "/gpfsscratch/rech/ajs/commun/Bigcode-large-megatron_conv/base3/"
device = "cuda"  # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
).to(device)

# Quick sanity check: generate a single token from the converted model
inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=1)
print(tokenizer.decode(outputs[0]))
```

## Citation

74 changes: 74 additions & 0 deletions dataset/oasst/add_commitpackft.py
@@ -0,0 +1,74 @@
import json
import sys
import random

# {"commit":"b4829b41402321a0a6f0f3c02766a5b25ebd09a1","old_file":"test\/clang-tidy\/hicpp-exception-baseclass.cpp","new_file":"test\/clang-tidy\/hicpp-exception-baseclass.cpp","old_contents":"\/\/ RUN: %check_clang_tidy %s hicpp-exception-baseclass %t\n\nnamespace std {\nclass exception {};\n} \/\/ namespace std\n\nclass derived_exception : public std::exception {};\nclass non_derived_exception {};\n\nvoid problematic() {\n try {\n throw int(42); \/\/ Built in is not allowed\n\/\/ CHECK-MESSAGES: [[@LINE-1]]:5: warning: throwing an exception whose type is not derived from 'std::exception'\n } catch (int e) {\n }\n throw int(42); \/\/ Bad\n\/\/ CHECK-MESSAGES: [[@LINE-1]]:3: warning: throwing an exception whose type is not derived from 'std::exception'\n\n try {\n throw non_derived_exception(); \/\/ Some class is not allowed\n\/\/ CHECK-MESSAGES: [[@LINE-1]]:5: warning: throwing an exception whose type is not derived from 'std::exception'\n\/\/ CHECK-MESSAGES: 8:1: note: type defined here\n } catch (non_derived_exception &e) { \n }\n throw non_derived_exception(); \/\/ Bad\n\/\/ CHECK-MESSAGES: [[@LINE-1]]:3: warning: throwing an exception whose type is not derived from 'std::exception'\n\/\/ CHECK-MESSAGES: 8:1: note: type defined here\n}\n\nvoid allowed_throws() {\n try {\n throw std::exception(); \/\/ Ok\n } catch (std::exception &e) { \/\/ Ok\n }\n throw std::exception();\n\n try {\n throw derived_exception(); \/\/ Ok\n } catch (derived_exception &e) { \/\/ Ok\n }\n throw derived_exception(); \/\/ Ok\n}\n","new_contents":"\/\/ RUN: %check_clang_tidy %s hicpp-exception-baseclass %t -- -- -fcxx-exceptions\n\nnamespace std {\nclass exception {};\n} \/\/ namespace std\n\nclass derived_exception : public std::exception {};\nclass non_derived_exception {};\n\nvoid problematic() {\n try {\n throw int(42); \/\/ Built in is not allowed\n\/\/ CHECK-MESSAGES: [[@LINE-1]]:5: warning: throwing an exception whose type is not derived from 'std::exception'\n } catch (int e) {\n }\n throw int(42); \/\/ Bad\n\/\/ CHECK-MESSAGES: [[@LINE-1]]:3: warning: throwing an exception whose type is not derived from 'std::exception'\n\n try {\n throw non_derived_exception(); \/\/ Some class is not allowed\n\/\/ CHECK-MESSAGES: [[@LINE-1]]:5: warning: throwing an exception whose type is not derived from 'std::exception'\n\/\/ CHECK-MESSAGES: 8:1: note: type defined here\n } catch (non_derived_exception &e) {\n }\n throw non_derived_exception(); \/\/ Bad\n\/\/ CHECK-MESSAGES: [[@LINE-1]]:3: warning: throwing an exception whose type is not derived from 'std::exception'\n\/\/ CHECK-MESSAGES: 8:1: note: type defined here\n}\n\nvoid allowed_throws() {\n try {\n throw std::exception(); \/\/ Ok\n } catch (std::exception &e) { \/\/ Ok\n }\n throw std::exception();\n\n try {\n throw derived_exception(); \/\/ Ok\n } catch (derived_exception &e) { \/\/ Ok\n }\n throw derived_exception(); \/\/ Ok\n}\n","subject":"Enable exceptions for this test case to speculatively fix the build bots.","message":"Enable exceptions for this test case to speculatively fix the build bots.\n\nHopefully corrects: http:\/\/lab.llvm.org:8011\/builders\/llvm-clang-lld-x86_64-scei-ps4-ubuntu-fast\/builds\/15666\n\ngit-svn-id: a34e9779ed74578ad5922b3306b3d80a0c825546@310732 91177308-0d34-0410-b5e6-96231b3b80d8\n","lang":"C++","license":"apache-2.0","repos":"llvm-mirror\/clang-tools-extra,llvm-mirror\/clang-tools-extra,llvm-mirror\/clang-tools-extra,llvm-mirror\/clang-tools-extra"}
path = sys.argv[1]
NUM_SELECT = 100  # note: defined but not used below

random.seed(42)

# Manually chosen line indices -- exactly one language block should be
# uncommented at a time (RUST is active below).
### CPP ### # 70 samples
#INDICES = [3, 10, 17, 19, 21, 23, 26, 27, 28, 35, 36, 38, 41, 912, 1791, 4139, 4597, 4598, 4647, 4791, 876, 4868, 2231, 2308, 2162, 4140, 4594, 1221, 1213, 3520, 326, 2044, 842, 2897, 1822, 3264, 2689, 3274, 226, 4189, 947, 1557, 4410, 2983, 573, 1020, 2460, 3105, 1425, 4488, 2350, 3892, 3730, 3395, 1592, 3145, 4049, 1208, 1792, 2025, 4270, 2249, 2225, 526, 3398, 3339, 4445, 3816, 1694, 3432]
#IDX_OFFSET = 493

### PYTHON ### # 89 samples
#INDICES = [7296, 48598, 16049, 9144, 48540, 35741, 5697, 27651, 1739, 36781, 13031, 35713, 27493, 38618, 53046, 425, 49729, 14110, 50036, 22059, 24898, 17335, 30108, 35142, 24807, 41198, 37837, 43336, 50663, 5229, 18217, 23909, 13730, 44796, 39920, 41613, 16043, 14392, 44866, 21252, 32717, 25928, 9363, 9150, 48823, 36789, 17219, 48956, 38242, 27666, 39086, 39052, 30674, 22293, 10365, 33271, 24504, 34757, 32021, 5613, 47966, 49497, 26148, 29588, 14725, 1378, 38564, 47780, 44129, 42824, 42348, 3972, 26386, 22236, 16295, 27648, 18254, 16371, 4940, 29041, 52954, 15491, 26282, 10789, 141, 14267, 35533, 49019, 4453]
#IDX_OFFSET = 394

### RUST ### # 75 samples
INDICES = [2619, 456, 1003, 419, 2771, 2233, 2418, 130, 952, 2069, 108, 2298, 1718, 1839, 1139, 418, 1470, 322, 2533, 1093, 2495, 1002, 669, 899, 2804, 938, 1643, 1620, 2989, 1010, 1076, 1729, 1576, 1917, 2167, 2994, 1858, 1078, 435, 1222, 2617, 2973, 1486, 237, 986, 2941, 350, 1990, 525, 2702, 2251, 676, 2484, 823, 1276, 1634, 2751, 1849, 495, 1015, 942, 937, 1140, 397, 1765, 1912, 221, 2676, 248, 1018, 2944, 2983, 2755, 1023, 2276]
IDX_OFFSET = 306

### GO ### # 72 samples
#INDICES = [204, 2253, 2006, 1828, 4467, 3456, 1791, 1905, 4931, 3436, 3679, 2278, 53, 1307, 3462, 1763, 2757, 2817, 4945, 3763, 1022, 3100, 2401, 2962, 1575, 375, 653, 3113, 2277, 3108, 2211, 4562, 1876, 2584, 542, 4646, 2577, 4998, 2020, 4598, 2020, 4780, 3271, 744, 898, 26, 871, 2444, 1629, 4889, 3063, 1323, 4418, 4344, 159, 2519, 4503, 552, 580, 1949, 1083, 1990, 2902, 3469, 4393, 3675, 4993, 3789, 3630, 3329, 2172, 4552]
#IDX_OFFSET = 263

### JAVA ### # 86 samples
#INDICES = [204, 2253, 2006, 1828, 4467, 3456, 1791, 1905, 4931, 3436, 3679, 2278, 53, 1307, 3462, 1763, 2757, 2817, 4945, 3763, 1022, 3100, 2401, 2962, 1575, 375, 653, 3113, 2277, 3108, 2211, 4562, 1876, 2584, 542, 4646, 2577, 4998, 2020, 4598, 2020, 4780, 3271, 744, 898, 26, 871, 2444, 1629, 4889, 3063, 1323, 4418, 4344, 159, 2519, 4503, 552, 580, 1949, 1083, 1990, 2902, 3469, 4393, 3675, 4993, 3789, 3630, 3329, 2172, 4552, 12418, 9347, 13861, 10276, 17403, 5158, 2245, 1302, 10366, 12969, 10360, 15017, 20368, 9912]
#IDX_OFFSET = 341

### JAVASCRIPT ### # 61 samples
#INDICES = [14628, 6717, 44348, 35741, 27651, 6140, 14328, 33118, 1739, 46925, 45962, 35713, 14446, 52810, 45753, 22298, 14110, 50036, 24898, 17335, 2847, 24807, 36178, 19213, 23700, 4558, 3003, 29714, 24260, 17496, 16043, 2103, 4337, 37170, 13934, 42954, 42129, 30071, 9363, 9150, 48823, 36789, 38311, 28077, 23723, 10484, 51909, 44596, 25009, 30674, 34676, 44583, 7507, 35190, 49209, 50370, 42006, 7310, 10365, 212, 47323]
#IDX_OFFSET = 187

ADD = True

if ADD:
    # Export the manually selected samples as prompt/completion pairs:
    # prompt = commit subject + "\n" + old file contents; completion = new contents.
    with open(path, "r") as f:
        lines = f.readlines()
    lines_chosen = [lines[i] for i in INDICES]
    with open("oasstcommitpackftmanual.jsonl", "a") as f:
        for l in lines_chosen:
            l = json.loads(l)
            data = {
                "prompt": l["subject"] + "\n" + l["old_contents"],
                "completion": l["new_contents"],
            }
            f.write(json.dumps(data, ensure_ascii=False) + "\n")
    sys.exit()


# Interactive selection loop (runs only when ADD is False): shuffle
# deterministically, skip the first IDX_OFFSET already-reviewed samples,
# and record the original indices of samples approved by hand.
with open(path, "r") as f:
    lines = f.readlines()

lines_shuffled = random.sample(lines, len(lines))
for j, l in enumerate(lines_shuffled):
    if j < IDX_OFFSET:
        continue
    i = lines.index(l)  # index in the original file (assumes unique lines)
    data = json.loads(l)
    print(data["subject"])
    is_ok = input("Is this OK? [y/n]")  # "n" breaks out; plain Enter skips the sample
    if is_ok == "y":
        INDICES.append(i)
    elif is_ok == "n":
        print("Breaking at index {}".format(j))
        break

print(INDICES)
print("Samples: ", len(INDICES))

