diff --git a/.gitignore b/.gitignore
index b6f9e1ee74f4..6c28913daec7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -71,6 +71,7 @@ doc/gallery
 doc/modules
 doc/plot_types
 doc/pyplots/tex_demo.png
+doc/image_search/
 doc/tutorials
 doc/users/explain
 lib/dateutil
diff --git a/doc/_static/image_search.css b/doc/_static/image_search.css
new file mode 100644
index 000000000000..3eec29741aaa
--- /dev/null
+++ b/doc/_static/image_search.css
@@ -0,0 +1,3 @@
+.sphx-glr-imgsearch-resultelement{
+    display: none;
+}
\ No newline at end of file
diff --git a/doc/_static/image_search.js b/doc/_static/image_search.js
new file mode 100644
index 000000000000..b0bdcd0e44e4
--- /dev/null
+++ b/doc/_static/image_search.js
@@ -0,0 +1,55 @@
+function cosineSimilarity(vec1, vec2) {
+    const dotProduct = vec1.map((val, i) => val * vec2[i]).reduce((accum, curr) => accum + curr, 0);
+    const vec1Size = calcVectorSize(vec1);
+    const vec2Size = calcVectorSize(vec2);
+
+    return dotProduct / (vec1Size * vec2Size);
+};
+
+function calcVectorSize(vec) {
+    return Math.sqrt(vec.reduce((accum, curr) => accum + Math.pow(curr, 2), 0));
+};
+
+
+data = []
+fetch('/_static/data.json')
+    .then( r => r.json() )
+    .then( d => { data = d } )
+
+
+function handle_search() {
+    if( data.length == 0 ){
+        return;
+    }
+
+    const container = document.getElementById('sphx-glr-imgsearchresult-container')
+    container.innerHTML = ""
+
+    result = {}
+    for (const [key, value] of data ) {
+        // just find the similar images to the image at the beginning of data
+        cos = cosineSimilarity( data[0][1], value)
+        result[cos] = key
+    }
+
+    result = Object.keys(result).sort().reduce(
+        (obj, key) => {
+            obj[key] = result[key];
+            return obj;
+        },
+        {}
+    );
+
+
+    Object.entries(result).map( ([key, value], index) => {
+        if( index > 5 ) return
+        const id = value;
+        const elem = document.getElementById( id );
+        container.innerHTML += elem.innerHTML
+    } )
+
+}
+
+window.addEventListener( 'load', () => {
+    document.getElementById('sphx-glr-imgsearchbutton').addEventListener( 'click', handle_search )
+} )
\ No newline at end of file
diff --git a/doc/conf.py b/doc/conf.py
index bc9b1ff7c1fa..5b9e41f9a0c4 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -124,6 +124,7 @@ def _parse_skip_subdirs_file():
     'sphinxext.mock_gui_toolkits',
     'sphinxext.skip_deprecated',
     'sphinxext.redirect_from',
+    'sphinxext.image_search',
     'sphinx_copybutton',
     'sphinx_design',
     'sphinx_tags',
@@ -245,7 +246,7 @@ def matplotlib_reduced_latex_scraper(block, block_vars, gallery_conf,
     return matplotlib_scraper(block, block_vars, gallery_conf, **kwargs)
 
 gallery_dirs = [f'{ed}' for ed in
-                ['gallery', 'tutorials', 'plot_types', 'users/explain']
+                ['gallery', 'tutorials', 'plot_types', 'users/explain', 'image_search']
                 if f'{ed}/*' not in skip_subdirs]
 
 example_dirs = []
@@ -477,6 +478,10 @@ def js_tag_with_cache_busting(js):
     "mpl.css",
 ]
 
+html_js_files = [
+    "image_search.js"
+]
+
 html_theme = "mpl_sphinx_theme"
 
 # The name for this set of Sphinx documents.  If None, it defaults to
diff --git a/doc/sphinxext/image_search.py b/doc/sphinxext/image_search.py
new file mode 100644
index 000000000000..73868190ed92
--- /dev/null
+++ b/doc/sphinxext/image_search.py
@@ -0,0 +1,257 @@
+import os
+import json
+import pandas as pd
+import numpy as np
+import torch
+import timm
+
+from xml.sax.saxutils import escape
+from PIL import Image
+from tqdm import tqdm
+from torchvision import transforms
+from torch.autograd import Variable
+
+from sphinx.util import logging as sphinx_logging
+from sphinx.errors import ExtensionError
+from sphinx_gallery import gen_gallery
+from sphinx_gallery.py_source_parser import split_code_and_text_blocks
+from sphinx_gallery.gen_rst import extract_intro_and_title
+from sphinx_gallery.backreferences import BACKREF_THUMBNAIL_TEMPLATE, _thumbnail_div, THUMBNAIL_PARENT_DIV, THUMBNAIL_PARENT_DIV_CLOSE
+from sphinx_gallery.scrapers import _find_image_ext
+
+
+logger = sphinx_logging.getLogger(__name__)
+
+
+
+class SearchSetup:
+    """ A class for setting up and generating image vectors."""
+    def __init__(self, model_name='vgg19', pretrained=True):
+        """
+        Parameters:
+        -----------
+        image_list : list
+            A list of images to be indexed and searched.
+        model_name : str, optional (default='vgg19')
+            The name of the pre-trained model to use for feature extraction.
+        pretrained : bool, optional (default=True)
+            Whether to use the pre-trained weights for the chosen model.
+        image_count : int, optional (default=None)
+            The number of images to be indexed and searched. If None, all images in the image_list will be used.
+        """
+        self.model_name = model_name
+        self.pretrained = pretrained
+        self.image_data = pd.DataFrame()
+        self.d = None
+        self.queue = []
+
+        base_model = timm.create_model(self.model_name, pretrained=self.pretrained)
+        self.model = torch.nn.Sequential(*list(base_model.children())[:-1])
+        self.model.eval() # disables gradient computation
+
+
+    def _extract(self, img):
+        # Resize and convert the image
+        img = img.resize((224, 224))
+        img = img.convert('RGB')
+
+        # Preprocess the image
+        preprocess = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229,0.224, 0.225]),
+        ])
+        x = preprocess(img)
+        x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False)
+
+        # Extract features
+        feature = self.model(x)
+        feature = feature.data.numpy().flatten()
+        return feature / np.linalg.norm(feature)
+
+    def _get_feature(self, image_data: list):
+        self.image_data = image_data
+        features = []
+        for img_path in tqdm(self.image_data): # Iterate through images
+            # Extract features from the image
+            try:
+                feature = self._extract(img=Image.open(img_path))
+                print(feature)
+                features.append(feature)
+            except:
+                # If there is an error, append None to the feature list
+                features.append(None)
+                continue
+        return features
+
+    def add_image( self, thumbnail_id, image_path ):
+
+        self.queue.append( (thumbnail_id, image_path) )
+
+    def start_feature_extraction(self):
+        data_df = pd.DataFrame()
+
+        image_paths = list( map( lambda x:x[1], self.queue ) )
+        data_df['image_path'] = image_paths
+
+        features = self._get_feature(image_paths)
+        data_df['feature'] = features
+
+        data_df['thumbnail_id'] = list( map( lambda x:x[0], self.queue ) )
+
+        f = open('./_static/data.json', "w")
+        data_json = []
+        for i in range(len(data_df)):
+            data_json.append( [ data_df.loc[i, "thumbnail_id"], data_df.loc[i, "feature"].tolist() ] )
+
+        f.write(json.dumps(data_json))
+
+
+
+
+
+# id="imgsearchref-{ref_name}" attribute is used by the js file later
+# to programmatically hide or unhide thumbnails depending on search result
+THUMBNAIL_TEMPLATE = """
+.. raw:: html
+