Skip to content

Commit

Permalink
fixed conflict
Browse files Browse the repository at this point in the history
Signed-off-by: BAAI-OpenPlatform <[email protected]>
  • Loading branch information
BAAI-OpenPlatform committed Jan 3, 2023
2 parents 9ff7c15 + 01c6bd9 commit cae3d0d
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 37 deletions.
65 changes: 39 additions & 26 deletions examples/AltCLIP/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,37 +219,42 @@ from PIL import Image
from flagai.auto_model.auto_loader import AutoLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
## 一行代码直接自动下载权重到'./checkpoints/clip-xlmr-large',并自动加载CLIP模型权重
## modelhub地址: Modelhub(https://model.baai.ac.cn/models)

loader = AutoLoader(
task_name="txt_img_matching",
model_dir="./checkpoints",
model_name="AltCLIP-XLMR-L"
model_name="AltCLIP-XLMR-L", # Load the checkpoints from Modelhub(model.baai.ac.cn/models)
model_dir="./checkpoints"
)
## 获取加载好的模型

model = loader.get_model()
## 获取tokenizer
tokenizer = loader.get_tokenizer()
## 获取transform用来处理图像
transform = loader.get_transform()

model.eval()
model.to(device)
tokenizer = loader.get_tokenizer()

## 推理过程,图像与文本匹配
image = Image.open("./dog.jpeg")
image = transform(image)
image = torch.tensor(image["pixel_values"]).to(device)
text = tokenizer(["a rat", "a dog", "a cat"])["input_ids"]

text = torch.tensor(text).to(device)

with torch.no_grad():
image_features = model.get_image_features(image)
text_features = model.get_text_features(text)
text_probs = (image_features @ text_features.T).softmax(dim=-1)

print(text_probs.cpu().numpy()[0].tolist())
def inference():
image = Image.open("./dog.jpeg")
image = transform(image)
image = torch.tensor(image["pixel_values"]).to(device)
tokenizer_out = tokenizer(["a rat", "a dog", "a cat"],
padding=True,
truncation=True,
max_length=77,
return_tensors='pt')

text = tokenizer_out["input_ids"].to(device)
attention_mask = tokenizer_out["attention_mask"].to(device)
with torch.no_grad():
image_features = model.get_image_features(image)
text_features = model.get_text_features(text, attention_mask=attention_mask)
text_probs = (image_features @ text_features.T).softmax(dim=-1)

print(text_probs.cpu().numpy()[0].tolist())

if __name__=="__main__":
inference()
```

## CLIP微调/Finetuning
Expand All @@ -271,6 +276,7 @@ from torchvision.datasets import (
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset_root = "./clip_benchmark_datasets"
dataset_name = "cifar10"

Expand All @@ -279,7 +285,7 @@ classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'hors

auto_loader = AutoLoader(
task_name="txt_img_matching",
model_dir="./checkpoints/",
model_dir="./checkpoints",
model_name="AltCLIP-XLMR-L" # Load the checkpoints from Modelhub(model.baai.ac.cn/models)
)

Expand All @@ -305,19 +311,26 @@ def cifar10_collate_fn(batch):
# image shape is (batch, 3, 224, 224)
images = torch.tensor([b[0]["pixel_values"][0] for b in batch])
# text_id shape is (batch, n)
input_ids = torch.tensor([tokenizer(f"a photo of a {b[1]}",padding=True,truncation=True,max_length=77)["input_ids"] for b in batch])
input_ids = torch.tensor([tokenizer(f"a photo of a {b[1]}",
padding=True,
truncation=True,
max_length=77)["input_ids"] for b in batch])

attention_mask = torch.tensor([tokenizer(f"a photo of a {b[1]}",
padding=True,
truncation=True,
max_length=77)["attention_mask"] for b in batch])

return {
"pixel_values": images,
"input_ids": input_ids
"input_ids": input_ids,
"attention_mask": attention_mask,
}

if __name__ == "__main__":
trainer.train(model=model, train_dataset=dataset, collate_fn=cifar10_collate_fn)
```



## 模型验证/Evaluation

我们提供了可以直接运行的验证脚本,在cifar10数据集上进行验证。
Expand Down
13 changes: 11 additions & 2 deletions examples/AltCLIP/altclip_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,20 @@ def cifar10_collate_fn(batch):
# image shape is (batch, 3, 224, 224)
images = torch.tensor([b[0]["pixel_values"][0] for b in batch])
# text_id shape is (batch, n)
input_ids = torch.tensor([tokenizer(f"a photo of a {b[1]}",padding=True,truncation=True,max_length=77)["input_ids"] for b in batch])
input_ids = torch.tensor([tokenizer(f"a photo of a {b[1]}",
padding=True,
truncation=True,
max_length=77)["input_ids"] for b in batch])

attention_mask = torch.tensor([tokenizer(f"a photo of a {b[1]}",
padding=True,
truncation=True,
max_length=77)["attention_mask"] for b in batch])

return {
"pixel_values": images,
"input_ids": input_ids
"input_ids": input_ids,
"attention_mask": attention_mask,
}

if __name__ == "__main__":
Expand Down
18 changes: 11 additions & 7 deletions examples/AltCLIP/altclip_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,23 @@
tokenizer = loader.get_tokenizer()

def inference():
image = Image.open("/home/yanzhaodong/anhforth/data/images/12.png")
image = Image.open("./dog.jpeg")
image = transform(image)
image = torch.tensor(image["pixel_values"]).to(device)
text = tokenizer(["a rat", "a dog", "a cat"])["input_ids"]

text = torch.tensor(text).to(device)

tokenizer_out = tokenizer(["a rat", "a dog", "a cat"],
padding=True,
truncation=True,
max_length=77,
return_tensors='pt')

text = tokenizer_out["input_ids"].to(device)
attention_mask = tokenizer_out["attention_mask"].to(device)
with torch.no_grad():
image_features = model.get_image_features(image)
text_features = model.get_text_features(text)
text_features = model.get_text_features(text, attention_mask=attention_mask)
text_probs = (image_features @ text_features.T).softmax(dim=-1)

print(text_probs.cpu().numpy()[0].tolist())

if __name__=="__main__":
inference()
inference()
1 change: 0 additions & 1 deletion flagai/data/tokenizer/uni_tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,6 @@ def add_command_token(self, name, token):
self.num_tokens += 1
self._command_tokens.append(CommandToken(name, token, id))
return

def rematch(self, text, tokens):
"""output the mapping relation between raw text and tokenized text
"""
Expand Down
2 changes: 1 addition & 1 deletion flagai/env_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(self,

def add_arg(self, arg_name, default=None, type=str, help="", store_true=False):
if store_true:
self.parser.add_argument(f"--{arg_name}", default=default, type=type, action="store_true", help=help)
self.parser.add_argument(f"--{arg_name}", action="store_true", help=help)
else :
self.parser.add_argument(f"--{arg_name}", default=default, type=type, help=help)

Expand Down

0 comments on commit cae3d0d

Please sign in to comment.