diff --git a/examples/AltCLIP/README.md b/examples/AltCLIP/README.md
index 066c0198..d5ca7ea1 100644
--- a/examples/AltCLIP/README.md
+++ b/examples/AltCLIP/README.md
@@ -219,37 +219,41 @@
 from PIL import Image
 from flagai.auto_model.auto_loader import AutoLoader
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-## 一行代码直接自动下载权重到'./checkpoints/clip-xlmr-large',并自动加载CLIP模型权重
-## modelhub地址: Modelhub(https://model.baai.ac.cn/models)
+
 loader = AutoLoader(
     task_name="txt_img_matching",
-    model_dir="./checkpoints",
-    model_name="AltCLIP-XLMR-L"
+    model_name="AltCLIP-XLMR-L",  # Load the checkpoints from Modelhub(model.baai.ac.cn/models)
+    model_dir="./checkpoints"
 )
-## 获取加载好的模型
+
 model = loader.get_model()
-## 获取tokenizer
 tokenizer = loader.get_tokenizer()
-## 获取transform用来处理图像
 transform = loader.get_transform()

 model.eval()
 model.to(device)

-## 推理过程,图像与文本匹配
-image = Image.open("./dog.jpeg")
-image = transform(image)
-image = torch.tensor(image["pixel_values"]).to(device)
-text = tokenizer(["a rat", "a dog", "a cat"])["input_ids"]
-
-text = torch.tensor(text).to(device)
-
-with torch.no_grad():
-    image_features = model.get_image_features(image)
-    text_features = model.get_text_features(text)
-    text_probs = (image_features @ text_features.T).softmax(dim=-1)
-
-print(text_probs.cpu().numpy()[0].tolist())
+def inference():
+    image = Image.open("./dog.jpeg")
+    image = transform(image)
+    image = torch.tensor(image["pixel_values"]).to(device)
+    tokenizer_out = tokenizer(["a rat", "a dog", "a cat"],
+                              padding=True,
+                              truncation=True,
+                              max_length=77,
+                              return_tensors='pt')
+
+    text = tokenizer_out["input_ids"].to(device)
+    attention_mask = tokenizer_out["attention_mask"].to(device)
+    with torch.no_grad():
+        image_features = model.get_image_features(image)
+        text_features = model.get_text_features(text, attention_mask=attention_mask)
+        text_probs = (image_features @ text_features.T).softmax(dim=-1)
+
+    print(text_probs.cpu().numpy()[0].tolist())
+
+if __name__=="__main__":
+    inference()
 ```

 ## CLIP微调/Finetuning
@@ -271,6 +275,7 @@ from torchvision.datasets import (
 )

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 dataset_root = "./clip_benchmark_datasets"
 dataset_name = "cifar10"

@@ -279,7 +284,7 @@ classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'hors

 auto_loader = AutoLoader(
     task_name="txt_img_matching",
-    model_dir="./checkpoints/",
+    model_dir="./checkpoints",
     model_name="AltCLIP-XLMR-L"  # Load the checkpoints from Modelhub(model.baai.ac.cn/models)
 )

@@ -305,19 +310,26 @@ def cifar10_collate_fn(batch):
     # image shape is (batch, 3, 224, 224)
     images = torch.tensor([b[0]["pixel_values"][0] for b in batch])
     # text_id shape is (batch, n)
-    input_ids = torch.tensor([tokenizer(f"a photo of a {b[1]}",padding=True,truncation=True,max_length=77)["input_ids"] for b in batch])
+    input_ids = torch.tensor([tokenizer(f"a photo of a {b[1]}",
+                                        padding=True,
+                                        truncation=True,
+                                        max_length=77)["input_ids"] for b in batch])
+
+    attention_mask = torch.tensor([tokenizer(f"a photo of a {b[1]}",
+                                             padding=True,
+                                             truncation=True,
+                                             max_length=77)["attention_mask"] for b in batch])

     return {
         "pixel_values": images,
-        "input_ids": input_ids
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
     }

 if __name__ == "__main__":
     trainer.train(model=model,
                   train_dataset=dataset,
                   collate_fn=cifar10_collate_fn)
 ```
-
-
 ## 模型验证/Evaluation
 我们提供了可以直接运行的验证脚本,在cifar10数据集上进行验证。
diff --git a/examples/AltCLIP/altclip_finetuning.py b/examples/AltCLIP/altclip_finetuning.py
index 8872387d..2b95fd4c 100644
--- a/examples/AltCLIP/altclip_finetuning.py
+++ b/examples/AltCLIP/altclip_finetuning.py
@@ -45,11 +45,20 @@ def cifar10_collate_fn(batch):
     # image shape is (batch, 3, 224, 224)
     images = torch.tensor([b[0]["pixel_values"][0] for b in batch])
     # text_id shape is (batch, n)
-    input_ids = torch.tensor([tokenizer(f"a photo of a {b[1]}",padding=True,truncation=True,max_length=77)["input_ids"] for b in batch])
+    input_ids = torch.tensor([tokenizer(f"a photo of a {b[1]}",
+                                        padding=True,
+                                        truncation=True,
+                                        max_length=77)["input_ids"] for b in batch])
+
+    attention_mask = torch.tensor([tokenizer(f"a photo of a {b[1]}",
+                                             padding=True,
+                                             truncation=True,
+                                             max_length=77)["attention_mask"] for b in batch])

     return {
         "pixel_values": images,
-        "input_ids": input_ids
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
     }

 if __name__ == "__main__":
diff --git a/examples/AltCLIP/altclip_inference.py b/examples/AltCLIP/altclip_inference.py
index c737bf9b..6ba53e41 100644
--- a/examples/AltCLIP/altclip_inference.py
+++ b/examples/AltCLIP/altclip_inference.py
@@ -19,19 +19,23 @@ tokenizer = loader.get_tokenizer()


 def inference():
-    image = Image.open("/home/yanzhaodong/anhforth/data/images/12.png")
+    image = Image.open("./dog.jpeg")
     image = transform(image)
     image = torch.tensor(image["pixel_values"]).to(device)
-    text = tokenizer(["a rat", "a dog", "a cat"])["input_ids"]
-
-    text = torch.tensor(text).to(device)
-
+    tokenizer_out = tokenizer(["a rat", "a dog", "a cat"],
+                              padding=True,
+                              truncation=True,
+                              max_length=77,
+                              return_tensors='pt')
+
+    text = tokenizer_out["input_ids"].to(device)
+    attention_mask = tokenizer_out["attention_mask"].to(device)
     with torch.no_grad():
         image_features = model.get_image_features(image)
-        text_features = model.get_text_features(text)
+        text_features = model.get_text_features(text, attention_mask=attention_mask)
         text_probs = (image_features @ text_features.T).softmax(dim=-1)

     print(text_probs.cpu().numpy()[0].tolist())

 if __name__=="__main__":
     inference()
diff --git a/flagai/data/tokenizer/uni_tokenizer/tokenizer.py b/flagai/data/tokenizer/uni_tokenizer/tokenizer.py
index b12a494e..a5b38477 100644
--- a/flagai/data/tokenizer/uni_tokenizer/tokenizer.py
+++ b/flagai/data/tokenizer/uni_tokenizer/tokenizer.py
@@ -367,7 +367,6 @@ def add_command_token(self, name, token):
         self.num_tokens += 1
         self._command_tokens.append(CommandToken(name, token, id))
         return
-
     def rematch(self, text, tokens):
         """output the mapping relation between raw text and tokenizezd text
         """
diff --git a/flagai/env_args.py b/flagai/env_args.py
index 49c5ce29..2cb72066 100644
--- a/flagai/env_args.py
+++ b/flagai/env_args.py
@@ -95,7 +95,7 @@ def __init__(self,

     def add_arg(self, arg_name, default=None, type=str, help="", store_true=False):
         if store_true:
-            self.parser.add_argument(f"--{arg_name}", default=default, type=type, action="store_true", help=help)
+            self.parser.add_argument(f"--{arg_name}", action="store_true", help=help)
         else :
             self.parser.add_argument(f"--{arg_name}", default=default, type=type, help=help)
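
The change these hunks share — tokenizing with padding/truncation and passing the resulting attention_mask into get_text_features — matters whenever the prompts in a batch have different lengths: padding=True pads shorter prompts to the longest one, and the mask lets the text encoder ignore the pad positions. The sketch below is a minimal zero-shot check of that path; it reuses only the FlagAI calls that appear in this patch, while the image path and the label prompts are placeholder assumptions.

```python
import torch
from PIL import Image
from flagai.auto_model.auto_loader import AutoLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

loader = AutoLoader(
    task_name="txt_img_matching",
    model_name="AltCLIP-XLMR-L",  # Load the checkpoints from Modelhub(model.baai.ac.cn/models)
    model_dir="./checkpoints"
)
model = loader.get_model()
tokenizer = loader.get_tokenizer()
transform = loader.get_transform()
model.eval()
model.to(device)

# Placeholder labels; the f-string template matches cifar10_collate_fn above.
prompts = [f"a photo of a {c}" for c in ("rat", "dog", "cat")]

image = transform(Image.open("./dog.jpeg"))  # placeholder image path
image = torch.tensor(image["pixel_values"]).to(device)

# padding=True pads every prompt to the longest one in the batch; the
# attention_mask marks real tokens (1) vs. pad tokens (0).
tokens = tokenizer(prompts, padding=True, truncation=True,
                   max_length=77, return_tensors='pt')

with torch.no_grad():
    image_features = model.get_image_features(image)
    text_features = model.get_text_features(tokens["input_ids"].to(device),
                                            attention_mask=tokens["attention_mask"].to(device))
    text_probs = (image_features @ text_features.T).softmax(dim=-1)

print(text_probs.cpu().numpy()[0].tolist())  # one probability per prompt
```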
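
The env_args.py hunk fixes a crash rather than a style issue: argparse's store_true action supplies its own default (False) and accepts no type= keyword, so forwarding default and type alongside action="store_true" raises a TypeError before the argument is even registered. A standalone repro of both the failure and the fixed behavior (plain argparse, no FlagAI imports):

```python
import argparse

parser = argparse.ArgumentParser()

# The old add_arg forwarded type= to a store_true action, which argparse rejects.
try:
    parser.add_argument("--broken", default=None, type=str, action="store_true")
except TypeError as e:
    print("old behaviour:", e)  # unexpected keyword argument 'type'

# The fixed add_arg passes only the flag name and help text.
parser.add_argument("--fixed", action="store_true", help="boolean flag")
print(parser.parse_args(["--fixed"]).fixed)  # True
print(parser.parse_args([]).fixed)           # False (store_true's own default)
```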