numpy - Memory Efficient L2 norm using Python broadcasting -


I am trying to implement a way to cluster points in a test dataset based on their similarity to a sample dataset, using Euclidean distance. The test dataset has 500 points, and each point is an N-dimensional vector (N=1024). The training dataset has around 10000 points, and each point is also a 1024-dim vector. The goal is to find the L2 distance between each test point and all the sample points in order to find the closest sample (without using any Python distance functions). Since the test array and the training array have different sizes, I tried using broadcasting:

    import numpy as np
    dist = np.sqrt(np.sum((test[:, np.newaxis] - train)**2, axis=2))

where test is an array of shape (500,1024) and train is an array of shape (10000,1024). I am getting a MemoryError. However, the same code works for smaller arrays. For example:

    test = np.array([[1, 2], [3, 4]])
    train = np.array([[1, 0], [0, 1], [1, 1]])
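For a rough sense of scale, the size of the broadcast intermediate can be estimated from the shapes alone. This is a back-of-the-envelope sketch that assumes float64 data (8 bytes per element); the dtype is not stated in the question, and the variable names are only illustrative.

```python
import numpy as np

n_test, n_train, dim = 500, 10000, 1024
itemsize = np.dtype(np.float64).itemsize  # 8 bytes; float64 is an assumption

# test[:, np.newaxis] - train materializes an array of shape (n_test, n_train, dim)
broadcast_bytes = n_test * n_train * dim * itemsize

# the final distance matrix only has shape (n_test, n_train)
result_bytes = n_test * n_train * itemsize

print(broadcast_bytes / 1e9)  # size of the 3-D intermediate, in GB
print(result_bytes / 1e6)     # size of the 2-D result, in MB
```

Under that assumption the 3-D intermediate needs roughly 41 GB, while the distance matrix itself needs only 40 MB, so the temporary array is the problem rather than the result.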

Is there a more memory-efficient way to do the above computation without loops? Based on posts online, the L2 norm can be implemented using matrix multiplication: sqrt(X * X - 2*X * Y + Y * Y). So I tried the following:

    x2 = np.dot(test, test.T)
    y2 = np.dot(train, train.T)
    xy = 2 * np.dot(test, train.T)

    dist = np.sqrt(x2 - xy + y2)

Since the matrices have different shapes, when I tried to broadcast there was a dimension mismatch, and I am not sure what the right way to broadcast is (I don't have much experience with Python broadcasting). I would like to know the right way to implement the L2 distance computation as a matrix multiplication in Python when the matrices have different shapes. The resultant distance matrix should have dist[i,j] = Euclidean distance between test point i and sample point j.

Thanks

Here is the broadcasting version, with the shapes of the intermediates made explicit:

    m = x.shape[0]  # x has shape (m, d)
    n = y.shape[0]  # y has shape (n, d)
    x2 = np.sum(x**2, axis=1).reshape((m, 1))
    y2 = np.sum(y**2, axis=1).reshape((1, n))
    xy = x.dot(y.T)  # shape (m, n)
    dists = np.sqrt(x2 + y2 - 2*xy)  # shape (m, n)
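As a sanity check, the expanded formula can be compared against the direct broadcasting version on the small arrays from the question. The sketch below wraps the computation in a hypothetical helper, `pairwise_l2` (not a NumPy function), and adds one step that is not in the snippet above: clipping the squared distances at zero, because floating-point rounding in `x2 + y2 - 2*xy` can produce tiny negative values that would turn into NaN under `np.sqrt`.

```python
import numpy as np

def pairwise_l2(x, y):
    """Hypothetical helper: Euclidean distances between rows of x (m, d) and y (n, d)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x2 = np.sum(x**2, axis=1)[:, np.newaxis]  # shape (m, 1)
    y2 = np.sum(y**2, axis=1)[np.newaxis, :]  # shape (1, n)
    xy = x.dot(y.T)                           # shape (m, n)
    sq = x2 + y2 - 2 * xy                     # broadcasts to (m, n)
    np.maximum(sq, 0, out=sq)                 # clip small negatives caused by rounding
    return np.sqrt(sq)

test = np.array([[1, 2], [3, 4]])
train = np.array([[1, 0], [0, 1], [1, 1]])

dists = pairwise_l2(test, train)

# Direct broadcasting version from the question, fine for arrays this small
reference = np.sqrt(np.sum((test[:, np.newaxis] - train) ** 2, axis=2))

# Index of the closest training point for each test point
nearest = np.argmin(dists, axis=1)
```

The largest array created here has shape (m, n), so for the sizes in the question the 3-D (500, 10000, 1024) temporary is never built.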

The documentation on broadcasting has some pretty good examples.
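The broadcasting rule the answer relies on is that a column of shape (m, 1) combined with a row of shape (1, n) yields an (m, n) array. A minimal illustration, using made-up arrays `col` and `row` rather than anything from the question:

```python
import numpy as np

col = np.arange(3).reshape((3, 1))  # shape (3, 1): values 0, 1, 2 down a column
row = np.arange(4).reshape((1, 4))  # shape (1, 4): values 0, 1, 2, 3 along a row

# Each size-1 axis is stretched to match the other operand,
# so total[i, j] == col[i, 0] + row[0, j]
total = col + row

print(total.shape)
```

This is the same mechanism that lets `x2` of shape (m, 1) and `y2` of shape (1, n) be added directly to `xy` of shape (m, n).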

