diff --git a/examples/linreg.py b/examples/linreg.py
index 46aead7..2a7e216 100644
--- a/examples/linreg.py
+++ b/examples/linreg.py
@@ -21,7 +21,7 @@
 # Form the affine registration instance.
 affine = register.Register(
     model.Affine,
-    metric.Residual,
+    metric.Dssim,
     sampler.Bilinear
     )
 
diff --git a/examples/nonlinreg.py b/examples/nonlinreg.py
index 2d20f7c..b393d29 100644
--- a/examples/nonlinreg.py
+++ b/examples/nonlinreg.py
@@ -53,7 +53,7 @@ def warp(image):
 # Form the spline registration instance.
 spline = register.Register(
     model.CubicSpline,
-    metric.Residual,
+    metric.Dssim,
     sampler.CubicConvolution
     )
 
diff --git a/register/metrics/metric.py b/register/metrics/metric.py
index 5bce219..afe4d0c 100644
--- a/register/metrics/metric.py
+++ b/register/metrics/metric.py
@@ -1,6 +1,9 @@
 """ A collection of image similarity metrics. """
-
+from ssim import calc_ssim, imageshash, flush_work
 import numpy as np
+import ssim
+
+
 
 class Metric(object):
     """
@@ -133,3 +136,96 @@ def error(self, warpedImage, template):
            Metric evaluated over all image coordinates.
         """
         return warpedImage.flatten() - template.flatten()
+
+
+
+def _dssim(image, reference, window):
+    """
+    Computes the per-pixel DSSIM map between image and reference.
+
+    Each pixel is scored as (1 - SSIM) / 2 over a window centred on it
+    and clipped at the image borders.
+    """
+    hashkey = imageshash(image,reference)
+    flush_work()
+    dssim = np.zeros(image.shape)
+    for j in range(image.shape[0]):
+        for i in range(image.shape[1]):
+            bbox = [max(0,i-window//2), max(0,j-window//2), min(image.shape[1], i+window//2), min(image.shape[0], j+window//2)]
+            ssim_ij = calc_ssim(image, reference, bbox, direct=False, imagekey=hashkey)
+            # DSSIM = (1 - SSIM) / 2, so a perfect match contributes zero error.
+            dssim[j,i] = (1.0 - ssim_ij) / 2.0
+    flush_work()
+    return dssim
+
+
+class Dssim(Metric):
+    """ DSSIM metric """
+
+    METRIC='dssim'
+
+    DESCRIPTION="""
+        The error metric which is computed as the DSSIM between the
+        deformed image and the template:
+
+            DSSIM(I(W(x;p)),T)
+
+        """
+
+    def __init__(self):
+        Metric.__init__(self)
+
+    def jacobian(self, model, warpedImage, p=None):
+        """
+        Computes the jacobian dP/dE.
+
+        NOTE(review): built from the warped image gradient only; the SSIM
+        term itself is not differentiated here.
+
+        Parameters
+        ----------
+        model: deformation model
+            A particular deformation model.
+        warpedImage: nd-array
+            Input image after warping.
+        p : optional list
+            Current warp parameters
+
+        Returns
+        -------
+        jacobian: nd-array
+            A jacobian matrix. (m x n)
+                | where: m = number of image pixels,
+                |        n = number of parameters.
+        """
+
+        grad = np.gradient(warpedImage)
+
+        dIx = grad[1].flatten()
+        dIy = grad[0].flatten()
+
+        dPx, dPy = model.jacobian(p)
+
+        J = np.zeros_like(dPx)
+        for index in range(0, dPx.shape[1]):
+            J[:,index] = dPx[:,index]*dIx + dPy[:,index]*dIy
+        return J
+
+    def error(self, warpedImage, template):
+        """
+        Evaluates the DSSIM metric.
+
+        Parameters
+        ----------
+        warpedImage: nd-array
+            Input image after warping.
+        template: nd-array
+            Template image.
+
+        Returns
+        -------
+        error: nd-array
+           Metric evaluated over all image coordinates.
+        """
+        d = _dssim(warpedImage, template, 11)
+        return d.flatten()
diff --git a/register/metrics/ssim.py b/register/metrics/ssim.py
new file mode 100644
index 0000000..b0fd77a
--- /dev/null
+++ b/register/metrics/ssim.py
@@ -0,0 +1,151 @@
+import numpy as np
+import hashlib
+import sys
+import os
+
+
+""" See http://en.wikipedia.org/wiki/Structural_similarity.
+
+    Also: Region Covariance: A fast descriptor for detection and classification.
+"""
+
+
+def ssim_direct(image, reference, bbox=None):
+    """ Calculates SSIM directly from the pixel statistics inside bbox.
+
+        @param image: The input image (MxN)
+        @param reference: The reference image (MxN)
+        @param bbox: Optional [x0, y0, x1, y1] region (default: whole image)
+        @return: The SSIM value of the region
+    """
+    if bbox is None:
+        bbox = [0, 0, image.shape[1], image.shape[0]]
+    img = image[bbox[1]:bbox[3], bbox[0]:bbox[2]].astype(np.uint8)
+    ref = reference[bbox[1]:bbox[3], bbox[0]:bbox[2]].astype(np.uint8)
+
+    meanI = np.mean(img.flatten())
+    meanR = np.mean(ref.flatten())
+    varI = np.cov(img.flatten(), img.flatten())[0,1]
+    varR = np.cov(ref.flatten(), ref.flatten())[0,1]
+    covIR = np.cov(img.flatten(), ref.flatten())[0,1]
+    L = (2.0**8)-1.0
+    k1 = 0.01
+    k2 = 0.03
+    c1 = (k1*L)**2.0
+    c2 = (k2*L)**2.0
+    SSIM = ((2.0*meanI*meanR + c1) * (2.0*covIR + c2)) / ((meanI**2.0 + meanR**2.0 + c1) * (varI + varR + c2))
+    return SSIM
+
+_saved_work = {}
+
+def flush_work():
+    """ Flush all saved work for previously processed images from memory.
+    """
+    global _saved_work
+    _saved_work = {}
+
+def _integral(image):
+    """ Calculates the integral image of the image provided.
+
+        Entry [row, col] holds the sum of image[:row+1, :col+1].
+
+        @param image: The input image (MxN)
+        @return: The integral image (MxN) of type float
+    """
+    return image.astype(float).cumsum(axis=0).cumsum(axis=1)
+
+def imageshash(image, reference):
+    """ Determine the image pair's hash.
+    """
+    return hashlib.md5(image.dumps()+reference.dumps()).hexdigest()
+
+def _p_Q(stack):
+    """ Calculates the p and Q matrices using the image stack.
+
+        Derivation in: Region Covariance - A Fast Descriptor for Detection
+        and Classification, Oncel Tuzel, Fatih Porikli, Peter Meer
+
+        @param stack: The image stack (MxNx2)
+        @return: The p and Q matrices, with the image axes ordered (x, y)
+    """
+    # Work in float so the per-pixel products cannot overflow integer types.
+    stack = stack.astype(float)
+    d = stack.shape[2]
+    x = stack.shape[1]
+    y = stack.shape[0]
+    p_xy = np.empty([d,x,y])
+    Q_xy = np.empty([d,d,x,y])
+    for i in range(d):
+        # _integral returns (row, col); transpose to the (x, y) layout.
+        p_xy[i,:,:] = _integral(stack[:,:,i]).T
+        for j in range(d):
+            Q_xy[i,j,:,:] = _integral(stack[:,:,i]*stack[:,:,j]).T
+    return (p_xy, Q_xy)
+
+
+def ssim_fast(image, reference, bbox=None, imagekey=None):
+    """ Calculates SSIM for bbox from cached integral images.
+
+        @param image: The input image (MxN)
+        @param reference: The reference image (MxN)
+        @param bbox: Optional [x0, y0, x1, y1] region (default: whole image)
+        @param imagekey: Optional cache key for the image pair
+        @return: The SSIM value of the region
+    """
+    if bbox is None:
+        bbox = [0, 0, image.shape[1], image.shape[0]]
+    upperleft = [0,0]
+    lowerright = [0,0]
+    upperleft[0] = bbox[1]
+    upperleft[1] = bbox[0]
+    lowerright[0] = bbox[3] - 1
+    lowerright[1] = bbox[2] - 1
+    if imagekey is None:
+        imagekey=imageshash(image,reference)
+    if imagekey not in _saved_work:
+        stack = np.dstack([image, reference])
+        p, Q = _p_Q(stack)
+        _saved_work[imagekey] = {'p':p, 'Q':Q}
+    else:
+        p = _saved_work[imagekey]['p']
+        Q = _saved_work[imagekey]['Q']
+    # Corner differences sum the pixels strictly after the upper-left corner.
+    QQQQ = Q[:,:,lowerright[1], lowerright[0]] + Q[:,:,upperleft[1], upperleft[0]] - Q[:,:,upperleft[1], lowerright[0]] - Q[:,:,lowerright[1], upperleft[0]]
+    pppp = p[:,lowerright[1], lowerright[0]] + p[:,upperleft[1], upperleft[0]] - p[:,upperleft[1], lowerright[0]] - p[:,lowerright[1], upperleft[0]]
+    n = (lowerright[1] - upperleft[1])*(lowerright[0] - upperleft[0])
+    # p * p.T is elementwise on a 1-D vector; the covariance needs p p^T.
+    C = 1.0 / (n - 1.0) * (QQQQ - 1.0 / n * np.outer(pppp, pppp))
+    meanI = 1.0 / n * pppp[0]
+    meanR = 1.0 / n * pppp[1]
+    varI = C[0,0]
+    varR = C[1,1]
+    covIR = C[0,1]
+    L = (2.0**8)-1.0
+    k1 = 0.01
+    k2 = 0.03
+    c1 = (k1*L)**2.0
+    c2 = (k2*L)**2.0
+    SSIM = ((2.0*meanI*meanR + c1) * (2.0*covIR + c2)) / ((meanI**2.0 + meanR**2.0 + c1) * (varI + varR + c2))
+    return SSIM
+
+
+def calc_ssim(image, reference, bbox=None, direct=True, imagekey=None):
+    """ Calculates SSIM directly or, if direct is False, via integral images.
+    """
+    assert image.shape == reference.shape
+    if direct:
+        return ssim_direct(image, reference, bbox)
+    else:
+        return ssim_fast(image, reference, bbox, imagekey)
+
+if __name__ == "__main__":
+    import osgeo.gdal as gdal
+    dsImage = gdal.Open(sys.argv[1])
+    dsReference = gdal.Open(sys.argv[2])
+    try:
+        fast = sys.argv[3] == '--fast'
+    except IndexError:
+        fast = False
+    image = dsImage.GetRasterBand(1).ReadAsArray()
+    reference = dsReference.GetRasterBand(1).ReadAsArray()
+    print "SSIM is: ", calc_ssim(image, reference, direct=(not fast))
diff --git a/register/register.py b/register/register.py
index 4b62804..802af9f 100644
--- a/register/register.py
+++ b/register/register.py
@@ -169,7 +169,7 @@ class Register(object):
     optStep = collections.namedtuple('optStep', 'error p deltaP')
 
     MAX_ITER = 200
-    MAX_BAD = 20
+    MAX_BAD = 50
 
     def __init__(self, model, metric, sampler):
 
@@ -319,7 +319,8 @@ def register(self,
                     warpedImage,
                     image.coords.tensor,
                     warp,
-                    '{0}:{1}'.format(model.MODEL, itteration)
+                    '{0}:{1}'.format(model.MODEL, itteration),
+                    metric
                     )
             else:
                 badSteps += 1
diff --git a/register/visualize/plot.py b/register/visualize/plot.py
index 41dfa8d..5205b96 100644
--- a/register/visualize/plot.py
+++ b/register/visualize/plot.py
@@ -128,7 +128,7 @@ def featurePlotSingle(image):
     featurePlt(image.features)
 
 
-def gridPlot(image, template, warpedImage, grid, warp, title):
+def gridPlot(image, template, warpedImage, grid, warp, title, metric=None):
 
     plt.subplot(2,3,1)
     plt.title('I')
@@ -161,8 +161,14 @@ def gridPlot(image, template, warpedImage, grid, warp, title):
     plt.axis('off')
 
     plt.subplot(2,3,4)
-    plt.title('I-T')
-    plt.imshow(template - image,
+    plt.title('E(I,T)')
+    if metric is None:
+        plt.imshow(template - image,
+                   origin=IMAGE_ORIGIN,
+                   cmap=IMAGE_COLORMAP
+                  )
+    else:
+        plt.imshow(metric.error(image,template).reshape(image.shape),
                origin=IMAGE_ORIGIN,
                cmap=IMAGE_COLORMAP
               )
@@ -175,11 +181,17 @@ def gridPlot(image, template, warpedImage, grid, warp, title):
     plt.title('W(x;p)')
 
     plt.subplot(2,3,6)
-    plt.title('W(I;p) - T {0}'.format(title))
-    plt.imshow(template - warpedImage,
-               origin=IMAGE_ORIGIN,
+    plt.title('E(W(I;p),T) {0}'.format(title))
+    if metric is None:
+        plt.imshow(template - warpedImage,
+                   origin=IMAGE_ORIGIN,
                cmap=IMAGE_COLORMAP
-              )
+                  )
+    else:
+        plt.imshow(metric.error(warpedImage,template).reshape(image.shape),
+                   origin=IMAGE_ORIGIN,
+                   cmap=IMAGE_COLORMAP
+                  )
     plt.axis('off')
 
     plt.draw()