A Visual Search Engine with Machine-Learning Indexing: a Python Example


Visual search is offered by most major search engines, and it is also useful in the e-commerce space and in-store experiences: I can show a picture of something I am wearing and get recommendations for items similar to the one I am looking for.

This project has two main parts: first, train on the data we have so that we can find images similar to a given image; then, return only the most similar products to the user.

We will index the image dataset by computing an embedding for each image, using a pre-trained network as a featurizer. We can then query the dataset with the K-NN algorithm.

An index is built over all input images. When a user then queries with a given image, its features are extracted and the images closest to it in the index are returned.

Dependencies (a pip-based install sketch follows this list):

  • Install mxnet

  • hnswlib (follow the guide at https://github.com/nmslib/hnsw)

  • The raw data comes from http://jmcauley.ucsd.edu/data/amazon/
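For a quick start, the Python dependencies can usually be installed with pip. This is a minimal sketch, assuming a notebook environment and pip-published packages; hnswlib can also be built from source following the guide linked above:

!pip install mxnet hnswlib wget matplotlib numpy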

Getting started:

Import the modules:

import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon.model_zoo import vision
import multiprocessing
from mxnet.gluon.data.vision.datasets import ImageFolderDataset
from mxnet.gluon.data import DataLoader
import numpy as np
import wget
import imghdr
import json
import pickle
import hnswlib
import glob, os, time
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from urllib.parse import urlparse
import urllib.request
import gzip
%matplotlib inline

Download the data:

This is a subset of the Amazon dataset, provided by Stanford.

data_path = 'metadata.json'
images_path = '/data/amazon_images_subset'

if not os.path.isfile(data_path):
    # Downloading the metadata, 3.1GB, unzipped 9GB
    !wget -nv https://s3.us-east-2.amazonaws.com/mxnet-public/stanford_amazon/metadata.json.gz
    !gzip -d metadata.json.gz

if not os.path.isdir(images_path):
    os.makedirs(images_path)

Process the data:

The full set of images is about 9 GB, but to get started we can work with a much smaller subset.

subset_num = 1000
num_lines = sum(1 for line in open(data_path))
assert num_lines >= subset_num, "Subset needs to be smaller or equal to total number of examples"

Each downloaded metadata record contains a URL to the product image; the file can be parsed as follows:

def parse(path, num_cpu, modulo):
    g = open(path, 'r')
    for i, l in enumerate(g):
        if (i >= num_lines - subset_num and i % num_cpu == modulo):
            yield eval(l)
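Each line of metadata.json is a Python dict literal, which is why eval works here. The fields used below are asin, imUrl and categories; a record looks roughly like this (field names from the code above, values invented purely for illustration):

{'asin': 'B0000000XX',
 'imUrl': 'http://ecx.images-amazon.com/images/I/xxxxxxxx.jpg',
 'categories': [['Clothing, Shoes & Jewelry', 'Men']]}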

Fetch the images from the data:

NUM_CPU = multiprocessing.cpu_count()*10

def download_files(modulo):
    for data in parse(data_path, NUM_CPU, modulo):
        if 'imUrl' in data and data['imUrl'] is not None and 'categories' in data and data['imUrl'].split('.')[-1] == 'jpg':
            url = data['imUrl']
            try:
                path = os.path.join(images_path, data['asin']+'.jpg')
                if not os.path.isfile(path):
                    file = urllib.request.urlretrieve(url, path)
            except:
                print("Error downloading {}".format(url))

Save the images and remove fake JPEGs:

pool = multiprocessing.Pool(processes=NUM_CPU)
results = pool.map(download_files, list(range(NUM_CPU)))

# Removing all the fake jpegs
list_files = glob.glob(os.path.join(images_path, '**.jpg'))
for file in list_files:
    if imghdr.what(file) != 'jpeg':
        print('Removed {} it is a {}'.format(file, imghdr.what(file)))
        os.remove(file)

Generate the image embeddings:

BATCH_SIZE = 256
EMBEDDING_SIZE = 512
SIZE = (224, 224)
MEAN_IMAGE = mx.nd.array([0.485, 0.456, 0.406])
STD_IMAGE = mx.nd.array([0.229, 0.224, 0.225])

Featurizer

We use a pre-trained model from the Gluon model zoo.

ctx = mx.cpu()
net = vision.resnet18_v2(pretrained=True, ctx=ctx)
net = net.features
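A quick sanity check, as a sketch: the feature block of resnet18_v2 should map a 224x224 image to a vector whose length matches EMBEDDING_SIZE.

dummy = mx.nd.zeros((1, 3, 224, 224), ctx=ctx)
print(net(dummy).shape)  # expect (1, 512) for resnet18_v2's feature extractor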

Data transform

Convert the images into the shape the network expects:

def transform(image, label):
    resized = mx.image.resize_short(image, SIZE[0]).astype('float32')
    cropped, crop_info = mx.image.center_crop(resized, SIZE)
    cropped /= 255.
    normalized = mx.image.color_normalize(cropped,
                                          mean=MEAN_IMAGE,
                                          std=STD_IMAGE)
    transposed = nd.transpose(normalized, (2,0,1))
    return transposed, label
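The transform can be exercised on a dummy HWC image of arbitrary size to confirm the output shape (a minimal sketch):

dummy = mx.nd.zeros((500, 375, 3), dtype='uint8')  # HWC uint8, as the dataset yields
out, label = transform(dummy, 0)
print(out.shape)  # (3, 224, 224): CHW float32, ready for the network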

Data loading

import os, tempfile, glob

empty_folder = tempfile.mkdtemp()
# Create an empty ImageFolderDataset, then point it at our flat list of files
dataset = ImageFolderDataset(root=empty_folder, transform=transform)
list_files = glob.glob(os.path.join(images_path, '**.jpg'))
dataset.items = list(zip(list_files, [0]*len(list_files)))
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, last_batch='keep', shuffle=False, num_workers=multiprocessing.cpu_count())

Featurization

features = np.zeros((len(dataset), EMBEDDING_SIZE), dtype=np.float32)

%%time
tick = time.time()
n_print = 100
net.hybridize()
for i, (data, label) in enumerate(dataloader):
    data = data.as_in_context(ctx)
    if i % n_print == 0 and i > 0:
        print("{0} batches, {1} images, {2:.3f} img/sec".format(i, i*BATCH_SIZE, BATCH_SIZE*n_print/(time.time()-tick)))
        tick = time.time()
    output = net(data)
    features[i*BATCH_SIZE:i*BATCH_SIZE+len(output), :] = output.asnumpy().squeeze()
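Featurization is the slow step, so it can be worth persisting the embeddings together with the file paths; pickle is already imported above. A sketch, with an arbitrary filename:

with open('features.pkl', 'wb') as f:
    pickle.dump({'features': features,
                 'paths': [item[0] for item in dataset.items]}, f)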

Create the search index:

# Number of elements in the index
num_elements = len(features)
labels_index = np.arange(num_elements)

%%time
# Declaring index
p = hnswlib.Index(space='l2', dim=EMBEDDING_SIZE)  # possible options are l2, cosine or ip

# Initing index - the maximum number of elements should be known beforehand
p.init_index(max_elements=num_elements, ef_construction=100, M=16)

# Element insertion (can be called several times):
int_labels = p.add_items(features, labels_index)

# Controlling the recall by setting ef:
p.set_ef(100)  # ef should always be > k

p.save_index('index.idx')
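The saved index can later be reloaded without rebuilding it, a sketch using hnswlib's load_index:

p = hnswlib.Index(space='l2', dim=EMBEDDING_SIZE)
p.load_index('index.idx', max_elements=num_elements)
p.set_ef(100)  # ef is a query-time parameter and must be set again after loading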

Testing

We test the results by picking random images from the dataset and searching for their K nearest neighbors:

def plot_predictions(images):
    gs = gridspec.GridSpec(3, 3)
    fig = plt.figure(figsize=(15, 15))
    gs.update(hspace=0.1, wspace=0.1)
    for i, (gg, image) in enumerate(zip(gs, images)):
        gg2 = gridspec.GridSpecFromSubplotSpec(10, 10, subplot_spec=gg)
        ax = fig.add_subplot(gg2[:, :])
        ax.imshow(image, cmap='Greys_r')
        ax.tick_params(axis='both',
                       which='both',
                       bottom='off',
                       top='off',
                       left='off',
                       right='off',
                       labelleft='off',
                       labelbottom='off')
        ax.axes.set_title("result [{}]".format(i))
        if i == 0:
            plt.setp(ax.spines.values(), color='red')
            ax.axes.set_title("SEARCH")

def search(N, k):
    # Query dataset, k - number of closest elements (returns 2 numpy arrays)
    q_labels, q_distances = p.knn_query([features[N]], k=k)
    images = [plt.imread(dataset.items[label][0]) for label in q_labels[0]]
    plot_predictions(images)

A few tests:

%%time
index = np.random.randint(0, len(features))
k = 6
search(index, k)
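Finally, to query with a brand-new image that is not in the dataset, as described in the introduction, the same transform and featurizer apply before hitting the index. A sketch, where 'query.jpg' is a placeholder path:

image = mx.nd.array(plt.imread('query.jpg'))  # placeholder: any RGB jpeg
data, _ = transform(image, 0)
embedding = net(data.expand_dims(axis=0).as_in_context(ctx))
q_labels, q_distances = p.knn_query(embedding.asnumpy(), k=6)
# Here result [0] is the closest match rather than the query image itself
plot_predictions([plt.imread(dataset.items[label][0]) for label in q_labels[0]])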
