Skip to content

Commit

Permalink
Merge pull request #24 from HugAILab/nchen-pr
Browse files Browse the repository at this point in the history
  • Loading branch information
nchen909 authored May 7, 2023
2 parents d80b4ca + 1603d4e commit 1f4445f
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 54 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ logs/
outputs/
**/__pycache__/
.history
nohup.out
8 changes: 7 additions & 1 deletion applications/benchmark/codexglue/run_codexglue_devign.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
#### pre-trained lm path
###
# # -*- coding: utf-8 -*-
# @Author: nchen909 NuoChen
# @Date: 2023-05-06 21:32:46
# @FilePath: /HugNLP/applications/benchmark/codexglue/run_codexglue_devign.sh
###
path=/root/autodl-tmp/CodePrompt/data/huggingface_models/codebert-base/
MODEL_TYPE=codebert

Expand All @@ -18,7 +24,7 @@ lr=1e-05


export CUDA_VISIBLE_DEVICES=0,1
python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=22025 hugnlp_runner.py \
python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=12383 hugnlp_runner.py \
--model_name_or_path=$path \
--data_dir=$data_path \
--output_dir=./outputs/codexglue/$codexglue_task \
Expand Down
151 changes: 151 additions & 0 deletions applications/code/HugClone/clone_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
'''
# -*- coding: utf-8 -*-
Author: nchen909 NuoChen
Date: 2023-05-07 16:59:19
FilePath: /HugNLP/applications/code/HugClone/clone_api.py
'''
import sys
sys.path.append("./")
sys.path.append("../")
sys.path.append("../../")
import os
from processors.code.code_clone.data_processor import CodeCloneProcessor
from models import CODE_MODEL_CLASSES
from models import TOKENIZER_CLASSES
import torch
from torch import nn

class HugCloneAPI:
    """Inference wrapper that scores whether two code snippets are clones.

    The tokenizer and the classification model are looked up by
    ``model_type`` in ``TOKENIZER_CLASSES`` and
    ``CODE_MODEL_CLASSES["code_cls"]`` and loaded once at construction time.
    """

    def __init__(self, model_type, hugcode_model_name_or_path) -> None:
        """Load the tokenizer and the classification model.

        Args:
            model_type: Key into ``CODE_MODEL_CLASSES["code_cls"]`` and
                ``TOKENIZER_CLASSES``.
            hugcode_model_name_or_path: Checkpoint location handed to
                ``from_pretrained``.

        Raises:
            KeyError: If ``model_type`` is not registered under
                ``CODE_MODEL_CLASSES["code_cls"]``.
        """
        code_cls_models = CODE_MODEL_CLASSES["code_cls"]
        if model_type not in code_cls_models:
            raise KeyError(
                "You must choose one of the following model: {}".format(
                    ", ".join(code_cls_models)))
        self.model_type = model_type
        self.tokenizer = TOKENIZER_CLASSES[self.model_type].from_pretrained(
            hugcode_model_name_or_path)
        # Load the checkpoint exactly once. ``from_pretrained`` is invoked on
        # the class itself; the previous code first built an instance from an
        # object that was itself a loaded model and then reloaded the weights.
        self.model = code_cls_models[self.model_type].from_pretrained(
            hugcode_model_name_or_path)
        # NOTE(review): assumes the loaded model exposes its configuration as
        # ``.config`` (Hugging Face convention) - confirm for custom classes.
        self.config = getattr(self.model, "config", None)
        self.max_source_length = 512
        self.max_target_length = 512

    def request(self, func1: str, func2: str):
        """Return the clone probability predicted for ``func1`` vs ``func2``.

        Args:
            func1: Source code of the first function.
            func2: Source code of the second function.

        Returns:
            The ``'prob'`` entry of the predictions produced by
            ``CodeCloneProcessor.get_predict_result``.
        """
        # Single-example "dataset"; the label and id are placeholder values.
        examples = [{'label': '0', 'func1': func1, 'func2': func2, 'id': 0}]
        processor = CodeCloneProcessor()
        preprocess_function = processor.build_preprocess_function()

        # A plain list has no ``.map``. Emulate a batched map instead: hand
        # the preprocess function a dict of columns, merge what it returns
        # with the original columns, then split back into one dict per row.
        # NOTE(review): presumes the preprocess function accepts a
        # dict-of-lists batch and returns a mapping of equal-length columns.
        columns = {key: [example[key] for example in examples]
                   for key in examples[0]}
        columns.update(preprocess_function(columns))
        inputs = [dict(zip(columns, row)) for row in zip(*columns.values())]

        collator = processor.get_data_collator()
        batch_input = collator(inputs)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = self.model(**batch_input)
        predictions, topk_result = processor.get_predict_result(
            outputs['logits'], examples, "test")
        clone_probability = predictions['prob']
        return clone_probability

if __name__ == "__main__":
    # Demo: score two pairs of Java functions with a fine-tuned clone model.
    # HugCloneAPI is defined in this module, so it is used directly instead
    # of re-importing the module that is currently being executed.
    # NOTE(review): model_type is "plbart" while the checkpoint path below
    # contains "codebert-base" - confirm that the two actually match.
    model_type = "plbart"
    hugclone_model_name_or_path = "/code/cn/HugAILab/HugNLP/outputs/code/clone/codebert-base/checkpoint-27300/"
    hugclone = HugCloneAPI(model_type, hugclone_model_name_or_path)

    ## JAVA code clone detection (HTTP fetch vs. file copy)
    pair1_func1 = """
public String getData(DefaultHttpClient httpclient) {
try {
HttpGet get = new HttpGet("http://3dforandroid.appspot.com/api/v1/note");
get.setHeader("Content-Type", "application/json");
get.setHeader("Accept", "*/*");
HttpResponse response = httpclient.execute(get);
HttpEntity entity = response.getEntity();
InputStream instream = entity.getContent();
responseMessage = read(instream);
if (instream != null) instream.close();
} catch (ClientProtocolException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
return responseMessage;
}
"""
    pair1_func2 = """
public static void copyFile(File in, File out) throws Exception {
FileChannel sourceChannel = new FileInputStream(in).getChannel();
FileChannel destinationChannel = new FileOutputStream(out).getChannel();
sourceChannel.transferTo(0, sourceChannel.size(), destinationChannel);
sourceChannel.close();
destinationChannel.close();
}
"""

    ## JAVA code clone detection (two file-copy implementations)
    pair2_func1 = """
public static void copyFile(File source, File dest) throws IOException {
FileChannel in = null, out = null;
try {
in = new FileInputStream(source).getChannel();
out = new FileOutputStream(dest).getChannel();
in.transferTo(0, in.size(), out);
} catch (FileNotFoundException fnfe) {
Log.debug(fnfe);
} finally {
if (in != null) in.close();
if (out != null) out.close();
}
}
"""

    pair2_func2 = """
public static void copyFile(File from, File to) throws IOException {
if (from.isDirectory()) {
if (!to.exists()) {
to.mkdir();
}
File[] children = from.listFiles();
for (int i = 0; i < children.length; i++) {
if (children[i].getName().equals(".") || children[i].getName().equals("..")) {
continue;
}
if (children[i].isDirectory()) {
File f = new File(to, children[i].getName());
copyFile(children[i], f);
} else {
copyFile(children[i], to);
}
}
} else if (from.isFile() && (to.isDirectory() || to.isFile())) {
if (to.isDirectory()) {
to = new File(to, from.getName());
}
FileInputStream in = new FileInputStream(from);
FileOutputStream out = new FileOutputStream(to);
byte[] buf = new byte[32678];
int read;
while ((read = in.read(buf)) > -1) {
out.write(buf, 0, read);
}
closeStream(in);
closeStream(out);
}
}
"""

    demo_pairs = [(pair1_func1, pair1_func2), (pair2_func1, pair2_func2)]
    for index, (func1, func2) in enumerate(demo_pairs):
        if index:
            # Blank separator between consecutive results, as in the
            # original two-step script.
            print("\n\n")
        clone_probability = hugclone.request(func1, func2)
        print("clone_probability:{}".format(clone_probability))

    # Output recorded by the original author:
    #   clone_probability:2.0006775685033062e-06
    #   clone_probability:0.9999953508377075
59 changes: 59 additions & 0 deletions applications/code/HugClone/run_clone_unified.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
###
# # -*- coding: utf-8 -*-
# @Author: nchen909 NuoChen
# @Date: 2023-05-07 16:59:40
# @FilePath: /HugNLP/applications/code/HugClone/run_clone_unified.sh
###
# Fine-tunes and evaluates a code clone classifier by launching
# hugnlp_runner.py on two GPUs through torch.distributed.launch.

#### pre-trained lm path
path=/root/autodl-tmp/CodePrompt/data/huggingface_models/plbart-base/
MODEL_TYPE=plbart

#### task data path (users should change this path)
data_path=/root/autodl-tmp/HugNLP/datasets/data_example/clone/

TASK_TYPE=code_cls
# TASK_TYPE=masked_prompt_prefix_cls

len=196
bz=4 # 8
epoch=10
eval_step=50
wr_step=10
lr=1e-05


export CUDA_VISIBLE_DEVICES=0,1
# Every continued line ends with " \" (space, then backslash) so adjacent
# options can never fuse into one argument; the last option has no backslash.
python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=6014 hugnlp_runner.py \
  --model_name_or_path="$path" \
  --data_dir="$data_path" \
  --output_dir=./outputs/code/clone_classification_plbart \
  --seed=42 \
  --exp_name=default-cls \
  --max_seq_length="$len" \
  --max_eval_seq_length="$len" \
  --do_train \
  --do_eval \
  --do_predict \
  --per_device_train_batch_size="$bz" \
  --per_device_eval_batch_size=4 \
  --gradient_accumulation_steps=1 \
  --evaluation_strategy=steps \
  --learning_rate="$lr" \
  --num_train_epochs="$epoch" \
  --logging_steps=100000000 \
  --eval_steps="$eval_step" \
  --save_steps="$eval_step" \
  --save_total_limit=1 \
  --warmup_steps="$wr_step" \
  --load_best_model_at_end \
  --report_to=none \
  --task_name=code_clone \
  --task_type="$TASK_TYPE" \
  --model_type="$MODEL_TYPE" \
  --metric_for_best_model=acc \
  --pad_to_max_length=True \
  --remove_unused_columns=False \
  --overwrite_output_dir \
  --label_names=labels \
  --keep_predict_labels \
  --user_defined="label_names=0,1"
48 changes: 0 additions & 48 deletions nohup.out

This file was deleted.

8 changes: 4 additions & 4 deletions processors/benchmark/codexglue/codexglue_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Date: 2023-05-06 16:16:16
FilePath: /HugNLP/processors/benchmark/codexglue/codexglue_processor.py
'''
"""Dataset utils for different data settings for GLUE."""
"""Dataset utils for different data settings for CodeXGLUE."""

import os
import copy
Expand All @@ -18,7 +18,7 @@
import random
import transformers
from processors.benchmark.codexglue.utils import DataProcessor, InputExample, InputFeatures, DefectExample, CloneExample
from transformers.data.processors.glue import *
# from transformers.data.processors.glue import *
from transformers.data.metrics import acc_and_f1
import dataclasses
from dataclasses import dataclass, asdict
Expand Down Expand Up @@ -61,7 +61,7 @@ def get_test_examples(self, data_dir):

def get_labels(self):
"""See base class."""
return ["0", "1"]
return [0,1]

def _create_examples(self, lines, set_type):
"""Creates examples for the training, dev and test sets."""
Expand Down Expand Up @@ -111,7 +111,7 @@ def get_test_examples(self, data_dir):

def get_labels(self):
"""See base class."""
return ["0", "1"]
return [0,1]

def _create_examples(self, lines, set_type):
"""Creates examples for the training, dev and test sets."""
Expand Down
2 changes: 1 addition & 1 deletion processors/benchmark/codexglue/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self,
self.raw_datasets = load_dataset("nchen909/devign-processed")
if self.data_name=='bcb':
self.raw_datasets = load_dataset("nchen909/bigclonebench-processed")
self.labels = self.raw_datasets["train"]["label"]#self.raw_datasets["train"].features["label"].names
self.labels = [0,1]#self.raw_datasets["train"].features["label"].names
self.sentence1_key, self.sentence2_key = task_to_keys[self.data_name]

def get_data_collator(self):
Expand Down

0 comments on commit 1f4445f

Please sign in to comment.