Skip to content

Commit

Permalink
使用faiss之前的版本
Browse files Browse the repository at this point in the history
  • Loading branch information
MuskAI committed Apr 26, 2022
1 parent 1d8fc57 commit 1fc6c27
Show file tree
Hide file tree
Showing 27 changed files with 1,473 additions and 91 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ graph LR;
8-->3
```
### **删除功能**

支持通过如下方式进行删除:
1. 删除所有类,e.g.要删除所有苹果的特征,则输入 ```['apple'] or 'apple' ```

删除所有类的实现思路:
1. 先通过find方法,找到要删除的index
2. 在delete方法中,通过find方法返回的index进行删除

### **数据库中保存的项如下**

Expand Down
125 changes: 99 additions & 26 deletions apis/construct_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
功能:从图片目录构建数据库,并提供增删改查的接口
"""
import cv2
import faiss
import pandas as pd
import os,shutil
import os,shutil,sys
sys.path.append('./apis')
from pprint import pprint
from hashlib import md5
import numpy as np
import time
from six.moves import cPickle
import re
from infer import infer

from tqdm import tqdm
from model_zoo.mobilenet_v2_md5 import MobileNetV2Feat

Expand Down Expand Up @@ -40,6 +44,8 @@ def __init__(self, img_dir='../database', cache_dir='../cache'):
self.samples = None
self.sample_cache = sample_cache
self.initial_sample_cache = initial_sample_cache
self.model_path = '../model_zoo/checkpoint/mobv2-circleloss-imagenet-e10.onnx'


def create_db(self):
"""
Expand All @@ -56,7 +62,7 @@ def create_db(self):
return None
else:
# 导入模型
method = MobileNetV2Feat(model_path='../model_zoo/checkpoint/ret_mobilenet_v2.onnx')
method = MobileNetV2Feat(model_path=self.model_path,state='insert')
samples = []
# 开始遍历图片
for root, _, files in tqdm(os.walk(img_dir, topdown=False)):
Expand All @@ -67,10 +73,20 @@ def create_db(self):
img = os.path.join(root, name)

# 开始编码
md5_code = self.get_md5(img_path=img)
md5_code ,new_path = self.get_md5(img_path=img)
if new_path is None:
pass
else:
img = new_path

# 目标检测在这个函数里面
query = method.make_single_sample(img, verbose=False, md5_encoding=False)

# 如果没有检出
if query is None:
print('图片{}因没有检出目标而未能加入样本库s'.format(img))
continue

# 开始构建数据库中的项
sample = {
'md5': md5_code,
def get_md5(self, img_path):
    """Return the MD5 hex digest of an image and its (possibly renamed) path.

    If the file name already embeds a digest (contains 'md5'), the code is
    parsed straight from the name and the file is left untouched.  Otherwise
    the file content is hashed and the file is renamed to
    '<stem>-md5-<digest>.<ext>' so future calls can skip re-hashing.

    Args:
        img_path: path to the image file, using '/' separators.

    Returns:
        (md5_code, new_path): ``new_path`` is the renamed path, or ``None``
        when the file name already carried a digest.  On any error returns
        ``(None, None)`` so callers that unpack two values do not crash.
    """
    try:
        name = img_path.split('/')[-1]
        new_path = None
        if 'md5' in name:
            # Digest already embedded in the file name: parse it directly.
            md5_code = name.split('-')[-1].split('.')[0]
        else:
            with open(img_path, 'rb') as img:
                md5_code = md5(img.read()).hexdigest()
            # BUG FIX: previous code assigned a (path, joined-path) tuple to
            # new_path and then called os.rename() with a single argument,
            # which raises TypeError.  Build the target path first, then
            # rename source -> target.
            new_path = os.path.join(
                img_path.replace(name, ''),
                '{}-md5-{}.{}'.format(name.split('.')[0], md5_code,
                                      name.split('.')[-1]))
            os.rename(img_path, new_path)
    except Exception:
        # Keep the two-value contract expected by callers such as create_db,
        # which unpack ``md5_code, new_path``.
        return None, None

    return md5_code, new_path

def insert(self, img_path_list=[], img_label_list=[]):
"""
Expand All @@ -130,21 +149,30 @@ def insert(self, img_path_list=[], img_label_list=[]):
Args:
img_path_list : 图片路径列表 e.g.['./test.jpg','./demo.png']
"""

assert self.samples, 'connect to db error'
walk_time = 0
encode_time = 0
check_duplicate_time = 0
s1=time.time()


# 使用默认的方式开始增量学习
img_info_list = []
if len(img_path_list) == 0 or len(img_label_list) == 0:
# 开始遍历图片,数据库
for root, _, files in tqdm(os.walk(self.img_dir, topdown=False)):
s2 = time.time()
cls = root.split('/')[-1]
for name in files:
if not name.endswith('.png') and not name.endswith('.jpg'):
continue
walk_time += (time.time()-s2)
img = os.path.join(root, name)

# 开始编码, 修改这里还可以提速
s3 = time.time()
md5_code = self.get_md5(img_path=img)
encode_time += (time.time()-s3)

img_info_list.append({
'md5': md5_code,
Expand All @@ -153,14 +181,19 @@ def insert(self, img_path_list=[], img_label_list=[]):
})

# 检查重复项
diffs = self.check_duplicate(samples=self.samples, img_info_list=img_info_list)

s4 = time.time()
diffs = self.check_duplicate(samples=self.samples, img_info_list=img_info_list)
check_duplicate_time = time.time()-s4
# 生成需要新增的图片列表
img_path_list = [i['img'] for i in diffs['insert']]
img_label_list = [i['cls'] for i in diffs['insert']]

method = MobileNetV2Feat(model_path='../model_zoo/checkpoint/ret_mobilenet_v2.onnx')
s5 = time.time()
method = MobileNetV2Feat(model_path=self.model_path)
load_mode_time = time.time() - s5

s6 = time.time()
for idx, item in enumerate(img_path_list):
# 首先判断路径是否存在
if os.path.isfile(item):
Expand All @@ -176,10 +209,17 @@ def insert(self, img_path_list=[], img_label_list=[]):
'error_times': 0 # 出错次数,默认为0
}
self.samples.append(sample)
study_time = time.time() - s6
# 如果有修改则重新保存数据库
s7 = time.time()
if len(img_label_list) != 0 and len(img_label_list) != 0:
cPickle.dump(self.samples, open(self.db_path, "wb", True))
print('\033[41m增量学习成功!数据库已更新!')
print('增量学习成功!数据库已更新!')
save_time = time.time()-s7
total_time = time.time() - s1
print('总用时{:.5f}\n遍历循环用时{:.5f}\n编码用时{:.5f}\n'
'集合差集用时{:.5f}\n加载模型用时{:.5f}\n增量学习用时{:.5f}\n保存文件用时{:.5f}\n'.format(total_time,
walk_time,encode_time,check_duplicate_time,load_mode_time,study_time,save_time))

def check_duplicate(self, samples, img_info_list):
"""
Expand Down Expand Up @@ -233,8 +273,11 @@ def delete(self, type='image',key=None, del_numk=0):
del_numk:仅仅在key=image的时候才会被使用,删除第几个,默认第一个为易混淆的特征,
"""
assert type in ('hist','md5','image'), 'Not support type {} at this time'.format(type)
assert type in ('hist','md5','image','cls'), 'Not support type {} at this time'.format(type)

# 记录一些删除前的状态
len_samples = len(self.samples)
all_cls = []

# 如果type 是 md5 或者hist
if type == 'hist' or type == 'md5':
Expand All @@ -248,17 +291,36 @@ def delete(self, type='image',key=None, del_numk=0):
index, sample = self.find(find_key=del_md5, type='md5')
else:
print('从数据库中删除特征失败!')

if index == None:
# 删除整个类别
elif type == 'cls':
# 在通过类别删除的时候就不需要find,可以输入'youzi'或['youzi','apple']
assert isinstance(key, (list, str)) # 暂时只支持通过'apple' or ['apple','garbage']的方式进行修改
if isinstance(key,str):
key = [key]
for idx, k_item in enumerate(self.samples[::-1]):
all_cls.append(k_item['cls'])
if k_item['cls'] in key:
self.samples.remove(k_item) # 删除
all_cls = set(all_cls)

if type != 'cls' and index == None:
return
try:
_sample = self.samples.pop(index)
# 保存数据库文件
sample_cache = '{}-{}-{}'.format(self.RES_model, self.pick_layer, 'md5')
cPickle.dump(self.samples, open(os.path.join(self.cache_dir, sample_cache), "wb", True))

if type != 'cls':
_sample = self.samples.pop(index)
else:
_sample = key
# TODO 确认是要删除的内容
print('成功从数据库中删除特征 {}'.format(_sample))
if len(self.samples) < len_samples:
# 保存数据库文件
sample_cache = '{}-{}-{}'.format(self.RES_model, self.pick_layer, 'md5')
cPickle.dump(self.samples, open(os.path.join(self.cache_dir, sample_cache), "wb", True))

print('成功从数据库中删除特征 {}'.format(_sample))
elif len(self.samples) == len_samples and type == 'cls':
print('删除失败!,{}类别不在数据库中,请重新输入'.format(key))
print('数据库中的类别为:',all_cls)


except Exception as e:
print('从数据库中删除特征失败!')
Expand All @@ -270,7 +332,7 @@ def find(self, find_key, type='hist'):
"""
集成各种在数据库查找的功能,返回的是在samples中的index 和 内容
"""
assert type in ('hist','md5','image'),'Not support type {} at this time'.format(type)
assert type in ('hist','md5','image','cls'),'Not support type {} at this time'.format(type)
assert self.samples is not None,'Please connect db before you using it.'
index = None
# 如果是根据特征查找
Expand All @@ -284,7 +346,7 @@ def find(self, find_key, type='hist'):
index = _.index(find_key)
elif type == 'image':
assert os.path.isfile(find_key),'{} is not exist!'.format(find_key)
method = MobileNetV2Feat(model_path='../model_zoo/checkpoint/ret_mobilenet_v2.onnx')
method = MobileNetV2Feat(model_path=self.model_path)
query = method.make_single_sample(find_key, verbose=False, md5_encoding=False)

# parameters
Expand All @@ -293,14 +355,25 @@ def find(self, find_key, type='hist'):
d_type = 'd1' # distance type you can choose 'd1 , d2 , d3 ... d8' and 'cosine' and 'square'
top_cls, result, std_result = infer(query, samples=self.samples, depth=topd, d_type=d_type, topk=topk, thr=1)
return std_result
elif type == 'cls':
#由于赶时间,先不考虑效率实现
results = []
for idx,item in enumerate(self.samples):
if item['cls'] == find_key:
results.append(item)


return results



else:
pass
except:
print('未找到要删除的特征,或已被删除')

if index is not None:
return index,self.samples[index]
return index, self.samples[index]
else:
return None,None

Expand Down Expand Up @@ -337,15 +410,15 @@ def __repr__(self):
db = Database()

# 创建数据库,如果数据库文件存在则需要删除才能创建
# db.create_db()
db.create_db()

# 连接数据库,使用之前都需要连接数据库
db.connect_db()
# db.connect_db()
# 新增样本,增量学习
# db.insert()
# 恢复出厂设置
# db.recover_db()

# 删除易混淆的特征,这里的易混淆特征是指:比如苹果经常被识别成土豆,就删除
# db.delete(type='md5',key='97becba942a9829b6fd9187a004740dd')
db.delete(type='image', key='/Users/musk/PycharmProjects/shengxian_retrieval_onnx/database/bailuobo/20211115-114804WCP-md5-2ff80b4530ce5aea4b20a52dc0611060.png')
# db.delete(type='cls',key='youzi')
# db.delete(type='image', key='/Users/musk/PycharmProjects/shengxian_retrieval_onnx/database/bailuobo/20211115-114804WCP-md5-2ff80b4530ce5aea4b20a52dc0611060.png')
45 changes: 45 additions & 0 deletions apis/faiss_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :shengxian_retrieval_onnx
@File :faiss_demo.py
@IDE :PyCharm
@Author :haoran
@Date :2022/4/25 17:19
'''
import numpy as np
import time
import faiss





if __name__ == '__main__':
    # Micro-benchmark: exact inner-product search with a flat faiss index.
    d = 1280      # feature dimension
    nb = 20000    # database size
    nq = 1        # number of queries
    np.random.seed(1234)  # make reproducible

    # L2-normalise both sides so inner product behaves like cosine similarity.
    xb = np.random.random((nb, d)).astype('float32')
    xb_len = np.linalg.norm(xb, axis=1, keepdims=True)
    xb = xb / xb_len
    xq = np.random.random((nq, d)).astype('float32')
    xq_len = np.linalg.norm(xq, axis=1, keepdims=True)
    xq = xq / xq_len

    # CPU exact search: IndexFlatIP ranks the whole database by inner product
    # (equivalently faiss.IndexFlat(d, faiss.METRIC_INNER_PRODUCT)).
    index = faiss.IndexFlatIP(d)
    index.add(xb)  # add vectors to the index

    # Retrieve every vector, i.e. a full ranking of the database.
    # (The old comment said "4 nearest neighbors", which was wrong.)
    nlist = 20000

    # BUG FIX: start timing only after the index is built, as the sibling
    # faiss_infer_kmean.py does, so the reported time covers search alone.
    t1 = time.time()
    for i in range(10):
        # NOTE(review): the query is negated here — presumably a deliberate
        # check that similarities flip sign; confirm this is intentional.
        D, I = index.search(-xq, nlist)  # actual search
    t2 = time.time()

    # Shift scores by +1 for inspection (same result as adding a list of 1s).
    print(D[0] + 1, I)
    # print(D[0], I)

    print('faiss spend time %.4f' % ((t2 - t1) / 10))

45 changes: 45 additions & 0 deletions apis/faiss_infer_kmean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :shengxian_retrieval_onnx
@File :faiss_demo.py
@IDE :PyCharm
@Author :haoran
@Date :2022/4/25 17:19
'''
import numpy as np
import time
import faiss





if __name__ == '__main__':
    # Micro-benchmark: approximate inner-product search with a faiss IVF
    # index (inverted file, flat storage) over random unit vectors.
    d = 1280  # dimension
    nb = 20000  # database size
    nq = 1  # nb of queries
    np.random.seed(1234)  # make reproducible
    # L2-normalise both sides so inner product behaves like cosine similarity.
    xb = np.random.random((nb, d)).astype('float32')
    xb_len = np.linalg.norm(xb, axis=1, keepdims=True)
    xb = xb / xb_len
    xq = np.random.random((nq, d)).astype('float32')
    xq_len = np.linalg.norm(xq, axis=1, keepdims=True)
    xq = xq / xq_len

    # CPU
    # Coarse quantizer: assigns each vector to one of 50 inverted lists.
    quantizer = faiss.IndexFlatIP(d)
    # index = faiss.IndexFlat(d, faiss.METRIC_INNER_PRODUCT)  # build the index
    index = faiss.IndexIVFFlat(quantizer, d, 50, faiss.METRIC_INNER_PRODUCT)
    index.train(xb)  # IVF indexes must be trained to learn the coarse centroids
    index.add(xb)  # add vectors to the index
    t1 = time.time()
    # NOTE(review): despite the name, this is the top-k returned per query,
    # not the IVF nlist (which is 50 above) — consider renaming.
    nlist = 5
    for i in range(10):

        # alternatively doable via faiss.IndexFlatIP (inner product)

        D, I = index.search(xq, nlist)  # actual search
    t2 = time.time()
    # Average per-search latency over the 10 iterations.
    print('faiss spend time %.4f' % ((t2 - t1) / 10))

Loading

0 comments on commit 1fc6c27

Please sign in to comment.