from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
# Synthetic data: 100 points drawn around 4 blob centers with unit spread;
# the fixed random_state makes the sample reproducible.
X, y = make_blobs(n_samples=100, centers=4, cluster_std=1, random_state=10)
# Plot the raw points using the first two feature columns, with no cluster
# colouring yet.
plt.scatter(x=X[:, 0], y=X[:, 1]);
from sklearn.cluster import KMeans
# Fit k-means with 4 clusters. fit() returns the estimator itself, so
# construction and fitting can be chained into one expression.
km = KMeans(n_clusters=4).fit(X)
# Cluster index assigned to each row of X.
km.labels_
array([3, 1, 3, 1, 0, 2, 2, 3, 0, 2, 3, 2, 0, 2, 0, 2, 0, 1, 0, 2, 2, 2, 1, 3, 2, 2, 2, 0, 3, 1, 2, 2, 0, 2, 2, 0, 3, 1, 3, 1, 1, 1, 1, 0, 3, 3, 2, 0, 0, 0, 0, 0, 2, 3, 0, 0, 2, 1, 1, 3, 0, 1, 1, 2, 2, 1, 0, 1, 0, 3, 3, 2, 1, 0, 2, 2, 3, 1, 2, 3, 0, 0, 1, 1, 0, 3, 1, 3, 1, 3, 0, 3, 3, 1, 1, 1, 3, 3, 3, 3], dtype=int32)
km.cluster_centers_
array([[ 5.49855163, -9.40880959], [ 2.79419702, 4.79694276], [ 0.348301 , -5.45307298], [-6.04098578, 5.06798706]])
X
array([[ -5.57785425, 5.87298826], [ 1.62783216, 4.17806883], [ -6.37184387, 4.41922347], [ 1.75005543, 5.44582908], [ 6.55010412, -7.9123388 ], [ -0.66982236, -5.19023657], [ 0.48085466, -5.08976945], [ -7.45962322, 4.53166747], [ 5.55912116, -10.06110303], [ -1.25569573, -5.72586023], [ -5.03188157, 4.91618824], [ 1.31006656, -5.47475738], [ 6.68288513, -10.31693051], [ 0.6769707 , -6.29133602], [ 5.69192445, -9.47641249], [ -0.07790108, -5.98485443], [ 6.2686376 , -9.38138022], [ 2.61105267, 4.22218469], [ 6.91094987, -10.6647659 ], [ -1.15296379, -5.89279504], [ -0.31748917, -6.86337766], [ 1.20634557, -3.03874201], [ 2.44078244, 4.47434875], [ -7.06349567, 5.37101341], [ 0.34789333, -3.88965912], [ 0.99265635, -5.33725682], [ 0.26308097, -5.97487434], [ 5.15516488, -8.97175683], [ -7.7498139 , 5.82291156], [ 2.60711685, 2.84436554], [ 1.35337248, -5.15783397], [ 2.37446585, -6.24342383], [ 5.4307043 , -9.75956122], [ -0.32584361, -4.65585848], [ 0.3024902 , -4.36909392], [ 6.08664442, -9.9358329 ], [ -5.09161663, 4.18830355], [ 2.62413419, 5.36941887], [ -5.79412818, 5.03331542], [ 2.98771848, 7.44372871], [ 3.78067293, 5.22062163], [ 3.38492372, 5.8943468 ], [ 2.79044036, 3.06862076], [ 4.46134719, -8.55668693], [ -5.68526509, 5.00333476], [ -5.49031464, 5.81381329], [ -0.2598064 , -6.63361828], [ 6.73488595, -9.38994773], [ 5.40050753, -9.29586681], [ 5.655043 , -9.1398234 ], [ 4.88653379, -8.87680099], [ 3.44868458, -11.32833331], [ 1.62685687, -4.83617748], [ -4.87157434, 4.63863743], [ 5.69248303, -7.19999368], [ 5.52556208, -8.18696464], [ -0.60297312, -6.82451464], [ 2.56069223, 4.6138972 ], [ 2.89022984, 2.98168388], [ -7.16921799, 3.26931456], [ 4.48697951, -10.07429823], [ 2.99232112, 5.43698055], [ 1.16464321, 5.59667831], [ -0.34268851, -5.85294901], [ 1.31468967, -5.01055177], [ 5.06766836, 5.89353659], [ 4.62182172, -9.79765865], [ 2.59184251, 4.44678157], [ 4.28981065, -9.44982413], [ -5.76456992, 4.69570432], [ -4.93861333, 5.77496677], [ 
-0.26686394, -5.44678194], [ 0.81677922, 4.75330395], [ 5.82662285, -9.92259335], [ -0.46888599, -5.36296292], [ 0.24318957, -7.12263784], [ -6.42972489, 6.46578798], [ 4.65804929, 6.7208918 ], [ 1.95552599, -4.05690149], [ -4.82059385, 5.15409352], [ 5.15909568, -10.13427003], [ 5.85943906, -8.38192364], [ 2.45717481, 5.96515011], [ 2.52859794, 4.5759393 ], [ 5.08727262, -9.27279108], [ -3.81369307, 5.32779566], [ 2.19087156, 5.06566526], [ -5.5048579 , 5.95458865], [ 3.49996332, 3.02156553], [ -6.44447223, 5.99238943], [ 5.99156553, -9.73238127], [ -6.75154214, 4.94975477], [ -8.33384603, 4.01468493], [ 3.80174985, 4.27826762], [ 2.31046552, 4.85417196], [ 3.71914756, 3.55752162], [ -7.11844009, 5.08754442], [ -6.74581415, 5.75727908], [ -5.0962423 , 2.23101747], [ -5.90560521, 6.41335812]])
# Data points, coloured by the cluster label k-means assigned to each one.
plt.scatter(X[:, 0], X[:, 1], c=km.labels_)
# Overlay the fitted cluster centers as large red dots.
plt.scatter(
    x=km.cluster_centers_[:, 0],
    y=km.cluster_centers_[:, 1],
    s=200,
    c='r',
    marker='o',
);
Inertia is the sum of squared distances from each data point to the center of its cluster. A smaller inertia means that the clusters are more tightly packed.
km.inertia_
187.49933118261387
# Same data clustered into only 2 groups, to compare its inertia with the
# 4-cluster fit.
km2 = KMeans(n_clusters=2).fit(X)
km2.inertia_
1691.3350029252008
Idea: try different numbers of clusters and check how the inertia changes.
# Elbow method: fit k-means once per candidate cluster count and record the
# inertia of each fit.
cluster_counts = range(2, 11)  # shared by the loop and the plot's x-axis
inertia_list = []
for i in cluster_counts:
    # `km` is rebound on every pass, so after the loop it refers to the last
    # (10-cluster) model.
    km = KMeans(n_clusters=i)
    km.fit(X)
    inertia_list.append(km.inertia_)
# Inertia versus number of clusters, drawn as red dots joined by a line.
plt.plot(cluster_counts, inertia_list, 'ro-');
# New synthetic dataset: 3 blob centers and a larger spread (cluster_std=5
# instead of 1), same seed as before.
X, y = make_blobs(n_samples=100, centers=3, cluster_std=5, random_state=10)
# Colour each point by the label make_blobs returned for it.
plt.scatter(x=X[:, 0], y=X[:, 1], c=y);
# Elbow method on the new dataset: fit k-means once per candidate cluster
# count and record the inertia of each fit.
cluster_counts = range(2, 11)  # shared by the loop and the plot's x-axis
inertia_list = []
for i in cluster_counts:
    # `km` is rebound on every pass; the cells that follow rely on it being
    # the last (10-cluster) model once the loop has finished.
    km = KMeans(n_clusters=i)
    km.fit(X)
    inertia_list.append(km.inertia_)
# Inertia versus number of clusters, drawn as red dots joined by a line.
plt.plot(cluster_counts, inertia_list, 'ro-');
km
KMeans(n_clusters=10)
km.transform(X)
array([[16.75494251, 16.61493925, 21.18778656, 8.13994105, 15.68793953, 26.63742076, 25.92456949, 22.17095862, 14.02963681, 34.1182112 ], [ 2.3157333 , 8.07063441, 8.61876223, 8.72593814, 4.95748316, 10.8464402 , 9.5273581 , 12.74197984, 17.06343243, 19.4935828 ], [ 4.47937032, 10.17266727, 6.60877727, 9.10117287, 3.02683895, 12.16827611, 8.99505325, 14.90505377, 18.73362464, 18.19580115], [ 7.65586021, 2.59484061, 18.4452566 , 12.30480412, 14.59562841, 10.73229532, 16.50174014, 3.42868185, 10.8999757 , 27.40252879], [11.89859662, 17.66184516, 1.34394996, 14.68756066, 6.16526106, 16.72063105, 8.96216263, 22.04537497, 25.90500139, 13.19430766], [ 6.11260483, 11.30750146, 7.37586384, 13.65212662, 7.52868542, 8.14407907, 4.67584359, 14.2104449 , 21.28362064, 15.24959959], [ 9.86428016, 4.57317535, 20.63275151, 14.00619139, 16.79151853, 11.84027405, 18.3901423 , 2.23203731, 10.43893286, 29.29113585], [ 7.82967995, 13.04709064, 5.74093881, 8.73395242, 0.76908902, 15.87682599, 11.23056181, 18.29140105, 20.23249302, 18.61332564], [11.64468611, 16.70808545, 7.57867022, 18.67691513, 11.12229182, 10.18070802, 1.00267019, 18.81697008, 26.80431059, 9.91328417], [ 7.52405332, 12.6634751 , 7.06653246, 14.91742426, 8.27723244, 8.32761781, 3.26957316, 15.30214342, 22.68725283, 13.8842851 ], [ 7.78113931, 2.33353491, 18.53324941, 9.63722131, 13.8500999 , 13.67535713, 18.04157957, 6.55550242, 8.06993321, 28.78742321], [15.29669057, 20.81015414, 5.9318592 , 15.72993324, 8.44360482, 21.23054179, 13.51246094, 25.73925536, 27.87919788, 15.65013275], [ 1.4692998 , 7.16603228, 9.6151993 , 9.51254041, 6.40532711, 9.54279605, 9.36411625, 11.44282681, 16.70358223, 19.74864293], [ 2.61639431, 4.53452243, 12.86269636, 10.55441275, 9.63113906, 8.20081405, 11.22302902, 8.09583506, 14.81306876, 22.03530323], [ 9.33760394, 6.85535561, 18.30015916, 5.4674635 , 12.72999526, 17.95965352, 20.06206457, 12.11312354, 7.22577515, 30.12655551], [14.75262842, 19.44063474, 10.60100487, 22.09718238, 
14.57228375, 11.08973338, 4.06934619, 20.77815145, 29.71752998, 8.14863178], [ 5.91584385, 5.47389159, 14.81417638, 4.26538156, 9.39524809, 15.24454821, 16.52209703, 11.46446872, 10.44910024, 26.52536985], [ 3.06619974, 5.68804587, 12.3283301 , 11.49487103, 9.64276447, 7.14213948, 10.1159814 , 8.5702647 , 16.02251859, 20.97604454], [ 6.9282631 , 1.20238025, 17.78292218, 9.93124432, 13.33326979, 12.46382779, 16.95820636, 5.87622132, 9.2686376 , 27.74954595], [ 5.92903199, 11.73041225, 4.97905887, 11.09425518, 3.74195178, 11.78186351, 7.1853025 , 15.97977401, 20.68165516, 16.13878642], [ 1.80502361, 4.7744912 , 12.10328284, 7.23482577, 7.62115417, 11.40782772, 12.56805237, 10.19049384, 13.56261755, 22.90667222], [ 8.79381739, 12.8835073 , 10.14343972, 17.03666709, 11.26940602, 5.56928323, 3.80740727, 14.21168221, 23.24446118, 14.43543179], [ 4.91393005, 10.06330505, 7.5499342 , 7.3005052 , 2.33227439, 13.77683405, 10.92498939, 15.32355732, 17.72836289, 19.77724619], [ 3.31187616, 7.35637639, 10.77523855, 11.94548124, 8.70848008, 6.88958514, 8.40206856, 10.21858758, 17.59333589, 19.23368073], [ 2.92678837, 7.66309999, 9.98460856, 11.42627445, 7.82633802, 7.63764941, 8.16319034, 10.93711445, 17.73624579, 18.88672308], [13.34934843, 18.28028862, 7.5226446 , 11.72732689, 6.32470186, 21.13581217, 15.00239495, 23.74986234, 24.16672032, 19.49779875], [ 7.74518046, 13.54659002, 3.18493548, 12.33546035, 4.1275173 , 12.81908545, 6.76096211, 17.71515594, 22.40308765, 14.70262858], [11.12518845, 16.24972467, 7.14622711, 18.10296312, 10.55708115, 10.08961491, 0.81789558, 18.49266955, 26.30666017, 10.32516588], [14.0103669 , 19.80786851, 3.38490178, 17.74013074, 9.13474285, 17.01093159, 8.19181007, 23.64807148, 28.58957119, 10.09053221], [ 5.39703077, 2.05668171, 16.12423484, 11.17136852, 12.42990077, 9.46411107, 14.33328853, 5.17841156, 12.26498639, 25.21317132], [ 6.13264407, 0.53931656, 16.98720636, 9.51190613, 12.56964739, 12.03384745, 16.23011839, 6.21820612, 9.87129305, 
26.99879089], [11.08706968, 16.452585 , 4.56275556, 11.54177932, 4.06072632, 18.15482587, 11.8660662 , 21.5731658 , 23.49135135, 17.1559407 ], [ 7.19018674, 12.86468868, 4.77869721, 13.30530124, 5.86047102, 10.70652261, 4.92715744, 16.42587855, 22.34001253, 14.19236464], [11.78092505, 7.3082788 , 22.09521592, 16.98864612, 18.87045166, 10.93778769, 18.75456038, 1.32253474, 13.10815694, 29.5182774 ], [ 6.52133961, 11.17534661, 7.94512769, 6.45874147, 2.16207412, 15.64138377, 12.41917265, 16.71542401, 17.88869294, 20.64486815], [ 7.37266018, 13.17249988, 3.52722748, 11.93853655, 3.8126542 , 12.73624054, 6.99129148, 17.4076786 , 21.98955161, 15.10830596], [ 2.72394449, 3.23378716, 13.57482499, 9.35738275, 9.72264254, 9.67107708, 12.62685887, 7.86641072, 13.30836228, 23.35514039], [14.63897707, 20.42723692, 4.15477405, 18.53497969, 9.92357643, 17.22113372, 8.22734453, 24.14217463, 29.31095816, 9.31982636], [ 1.80156746, 5.27971787, 11.6851714 , 6.95461895, 7.11290129, 11.66749112, 12.42561679, 10.71606767, 13.85072464, 22.65937897], [ 0.43822395, 5.3735614 , 11.31232142, 8.51024335, 7.41511798, 10.121057 , 11.17865327, 10.15466165, 14.8125573 , 21.64223557], [12.51630044, 15.70140411, 12.42280443, 6.71885974, 7.83352275, 22.14392206, 18.56708233, 21.68223788, 19.01411847, 25.26295068], [ 3.96552034, 7.10573234, 11.72966223, 12.59702596, 9.74601072, 6.0686339 , 8.80484119, 9.39994171, 17.46559641, 19.69963648], [ 9.31130733, 11.25821436, 13.44443318, 1.87751558, 7.74223082, 19.33560135, 17.82200743, 17.27798896, 14.34486577, 26.26421826], [ 4.26672497, 4.28675467, 13.97399326, 5.6870683 , 8.90638861, 13.44578789, 15.01086574, 10.28328106, 11.34030891, 25.2406113 ], [12.68790306, 17.51237689, 8.98784536, 19.99025968, 12.56690677, 9.99400366, 1.95836665, 19.19652795, 27.72960087, 9.45396986], [ 7.76838297, 2.03559781, 18.65728104, 11.28084056, 14.42813754, 12.08521942, 17.31542227, 4.58089001, 9.50338359, 28.18042038], [ 7.31350461, 3.19353046, 17.89930501, 12.86379084, 
14.37878717, 9.48150457, 15.5259591 , 3.20828412, 12.1585483 , 26.42965806], [ 3.15204971, 7.28925338, 10.32142702, 5.87668742, 5.22839569, 13.10786407, 12.36418249, 12.85024085, 14.91642531, 22.03906277], [ 8.90586059, 11.6044272 , 12.06212322, 3.05293849, 6.41839842, 18.8456153 , 16.73466452, 17.59990681, 15.56683909, 24.92394012], [13.83765433, 9.13542423, 24.17195713, 18.68170717, 20.91983614, 12.53948369, 20.68634526, 3.35131049, 13.46693515, 31.38330831], [22.44814038, 27.77248089, 14.07141858, 28.21375786, 19.69224414, 20.31945707, 12.11017813, 29.81676171, 37.68310475, 1.23104611], [ 7.35316019, 8.02599185, 15.02912826, 15.66322295, 13.49620243, 3.4767599 , 10.49166523, 7.55560757, 18.13858111, 21.1754528 ], [16.01118037, 12.49682298, 24.7119701 , 10.6610505 , 18.99566276, 23.99107032, 26.76304732, 16.19641182, 4.69147157, 36.79162869], [ 8.02045818, 10.79763399, 12.59656122, 16.66063315, 12.45138092, 2.80423591, 6.97383778, 11.24557079, 21.17231412, 17.50394 ], [ 5.23390362, 5.80630637, 13.81575068, 4.06791452, 8.39920417, 14.86153911, 15.68285653, 11.82545704, 11.45024928, 25.58792574], [13.65595254, 13.39483177, 19.9795499 , 21.8894412 , 19.46802541, 4.60125418, 13.64743411, 10.06639342, 22.47175216, 22.87761014], [11.57629061, 17.35301624, 2.08216464, 16.07406828, 7.50459152, 14.57302598, 6.15627358, 21.07057414, 26.39674704, 11.02856187], [ 1.66505126, 7.23619793, 9.49906454, 8.04232181, 5.43150822, 11.07735722, 10.40904958, 12.14246405, 16.10127207, 20.45138536], [ 8.41162097, 3.83705026, 19.02216239, 13.57032233, 15.45836092, 10.10447962, 16.54767288, 2.22024545, 11.76694425, 27.445522 ], [ 5.63841332, 1.44679112, 16.34514235, 8.34332743, 11.70282431, 12.5956293 , 16.1054639 , 7.38625499, 9.80483679, 26.76244095], [21.87769306, 27.53586019, 11.99192383, 26.388371 , 17.77274035, 21.86135049, 12.7818789 , 30.45411588, 36.89886426, 4.48553137], [12.40289419, 14.99351633, 14.78809241, 21.01163733, 16.0139168 , 3.91241251, 7.47353047, 14.32419488, 
25.26562863, 16.04301261], [ 9.83082333, 11.05649737, 14.73328328, 1.33033734, 8.99674074, 19.84005811, 18.8544438 , 17.04323469, 13.22261203, 27.50887418], [ 5.94837975, 11.59650899, 5.36811347, 9.54436571, 1.93423522, 13.26071862, 9.0167029 , 16.38242474, 19.86854285, 17.49660242], [ 4.3254928 , 1.99037328, 15.14983733, 10.16093919, 11.30266812, 9.7420203 , 13.80120198, 6.33931525, 12.37649716, 24.62488478], [11.22936082, 6.65748009, 21.63259828, 16.32967713, 18.30993658, 10.89732162, 18.48133326, 0.74724869, 12.63078926, 29.28697731], [ 3.97397632, 3.22500419, 14.50459141, 10.90246805, 11.06901405, 8.53263328, 12.68629955, 6.52853655, 13.62157379, 23.54980924], [ 2.98797809, 8.76625684, 7.91866954, 9.10349155, 4.51540811, 10.94030934, 9.0144785 , 13.34356495, 17.7488928 , 18.84612396], [ 7.32842019, 6.79732995, 15.51549478, 3.20119135, 9.89727621, 16.7820244 , 17.76649555, 12.69870327, 10.05436754, 27.54926301], [ 5.18170142, 10.93191579, 5.79258321, 9.82694086, 2.93987472, 12.21449917, 8.38168715, 15.51655042, 19.58194885, 17.37840609], [20.28856186, 25.06968069, 14.28990482, 27.23995493, 19.21502728, 16.19324361, 9.50642497, 26.23332191, 35.33263349, 4.69510099], [ 9.16246193, 7.0909262 , 18.39134792, 16.32676525, 16.05179991, 6.34764025, 14.28369674, 4.07189177, 16.00053699, 24.94871507], [ 8.22535156, 14.01005685, 2.68737013, 12.15732642, 3.69402253, 13.69112097, 7.48978687, 18.36186199, 22.60567262, 14.87805343], [10.74979987, 16.5387096 , 0.25428688, 14.16388686, 5.54498893, 15.41004298, 7.87102593, 20.7961501 , 25.03623443, 13.19420005], [ 3.14671773, 5.43518 , 12.28080852, 5.66239165, 7.25916255, 12.97375675, 13.6169771 , 11.23209836, 13.00654286, 23.67375113], [ 8.39469598, 13.538931 , 5.84123625, 8.80864439, 1.37430074, 16.4767666 , 11.67544693, 18.84032321, 20.51156009, 18.76802749], [11.1942473 , 15.92692774, 8.76554561, 18.72607945, 11.65637816, 8.66541172, 1.04141849, 17.62881386, 26.16860881, 11.01265305], [15.87166919, 10.19956115, 26.58899352, 
15.62188356, 21.67147293, 20.01341109, 25.86480812, 9.9302742 , 4.69147157, 36.71377508], [ 1.98698701, 3.82115295, 12.86261445, 8.47411625, 8.7787422 , 10.31841885, 12.47408917, 8.86089847, 13.43008316, 23.06710662], [ 6.13935751, 0.46009776, 17.02865386, 10.17467106, 12.80617588, 11.38768615, 15.92541062, 5.56863518, 10.37272544, 26.74611531], [ 4.2567513 , 3.52576123, 14.62492089, 11.35818657, 11.34213862, 8.11820089, 12.55278823, 6.31991525, 13.91079726, 23.43703385], [13.17580858, 18.88077685, 2.95917364, 15.15569745, 6.92669784, 18.33279942, 10.49255918, 23.44886127, 26.75555972, 13.6938363 ], [ 6.52747686, 3.68342161, 16.54385552, 6.47341071, 11.40102915, 14.65350731, 17.30107144, 9.4165567 , 8.7990449 , 27.69644991], [14.34151903, 19.81162375, 6.70049197, 20.1917446 , 11.84135326, 14.08683119, 4.84343756, 22.46506502, 29.56036272, 6.94227407], [ 7.68200432, 9.2988088 , 13.37386806, 1.11134345, 7.59244066, 17.694021 , 16.92840814, 15.31869804, 13.07904022, 25.9408323 ], [10.57381853, 16.35324913, 0.59001851, 13.84898891, 5.23345281, 15.46024016, 8.0748444 , 20.67909062, 24.77642048, 13.52971795], [11.71900826, 12.26512645, 16.69671864, 3.07386633, 11.00320802, 21.67019837, 20.93221223, 18.13154733, 12.71811233, 29.52518175], [ 3.46642667, 2.55823971, 14.24137928, 8.06903147, 9.83446702, 11.29147819, 14.00599199, 8.21425614, 11.89233714, 24.60960691], [ 2.54454597, 7.53835816, 9.82928169, 10.98341118, 7.44809779, 8.09083785, 8.38979486, 11.06216636, 17.51035294, 19.04134468], [ 2.78851561, 8.48758785, 8.23391282, 8.56192676, 4.42951521, 11.28740024, 9.56136022, 13.24694956, 17.29569854, 19.3494358 ], [ 7.55339962, 2.04501775, 18.33349453, 9.67936296, 13.70161239, 13.38212551, 17.76939742, 6.38909321, 8.36360246, 28.5254254 ], [13.94957912, 8.55343262, 24.68005111, 17.49997907, 20.86893824, 14.50625537, 22.00438253, 4.21748611, 10.88085787, 32.85008306], [ 7.77803663, 2.02233748, 18.6316379 , 10.41585676, 14.14842144, 12.97254416, 17.74719557, 5.62791675, 
8.64958943, 28.55812033], [ 2.82586532, 4.95023228, 12.46666644, 6.16589665, 7.6005984 , 12.50410725, 13.47007043, 10.69441694, 12.91441036, 23.65339578], [ 5.00960027, 8.01876873, 11.24932303, 3.94650077, 5.64645658, 15.02790106, 14.13476136, 13.89569488, 14.19627285, 23.45973075], [ 7.53135963, 11.39035955, 9.61652149, 4.93041644, 3.92168107, 17.12664885, 14.32448217, 17.19856768, 16.98366065, 22.43872952], [ 0.49073567, 5.92825498, 10.76290914, 8.19089298, 6.7572908 , 10.50960704, 11.00447248, 10.81270715, 15.12246836, 21.33720063], [ 8.06167598, 4.69905553, 17.90963566, 6.69782556, 12.62350104, 16.00385092, 18.83862743, 9.92407917, 7.34641747, 29.20843045], [ 7.12215682, 4.86268216, 17.01807982, 13.99741222, 14.15882844, 7.25365578, 13.83226343, 4.02122515, 14.390401 , 24.69602647], [ 3.90108446, 2.08344485, 14.77069326, 9.63900198, 10.81115071, 10.02306928, 13.68452686, 6.87470049, 12.35397117, 24.4646711 ]])
from pathlib import Path
import requests
import numpy as np
import gzip
# Fetch the gzipped MNIST training images and labels (skipping files that are
# already present in the working directory) and decode them into NumPy arrays.
mnist_url = "http://yann.lecun.com/exdb/mnist/"
img_file = "train-images-idx3-ubyte.gz"
labels_file = "train-labels-idx1-ubyte.gz"

for fname in [img_file, labels_file]:
    if Path(fname).is_file():
        print(f"Found: {fname}")
        continue
    print(f"Downloading: {fname}")
    # A timeout keeps the cell from hanging forever on an unresponsive server.
    r = requests.get(mnist_url + fname, timeout=60)
    # Fail loudly on an HTTP error instead of saving the error page under the
    # .gz name, where the next run would report it as "Found".
    r.raise_for_status()
    with open(fname, 'wb') as foo:
        foo.write(r.content)

# Image file: skip the first 16 bytes (header) and reshape the remaining
# bytes into one row of 28*28 values per image.
with gzip.open(img_file, 'rb') as foo:
    f = foo.read()
# np.frombuffer decodes the whole buffer at once instead of looping over
# every byte in Python; astype(int) yields a writable array with NumPy's
# default integer dtype, as the list-based construction did.
images = np.frombuffer(f, dtype=np.uint8, offset=16).astype(int).reshape(-1, 28*28)

# Label file: skip the first 8 bytes (header); one byte per label follows.
with gzip.open(labels_file, 'rb') as foo:
    f = foo.read()
labels = np.frombuffer(f, dtype=np.uint8, offset=8).astype(int)
Found: train-images-idx3-ubyte.gz Found: train-labels-idx1-ubyte.gz
images
array([[0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], ..., [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0]])
labels
array([5, 0, 4, ..., 5, 6, 8])
# Keep only the first 5000 images together with their labels.
images, labels = images[:5000], labels[:5000]
# Two-cluster k-means over the flattened (one row per image) pixel data.
km = KMeans(n_clusters=2)
km.fit(images)
KMeans(n_clusters=2)
reduced = km.transform(images)
reduced
array([[1881.19641716, 1826.21606123], [2154.46745948, 1740.15736307], [1929.54455919, 2069.64256587], ..., [2041.06512567, 1827.45759362], [1302.81517557, 1820.48677329], [1764.73986247, 1867.56544165]])
# seaborn has to be imported before `sns.__version__` can be read; with the
# import after the version check the file raises NameError when run in order.
import seaborn as sns
sns.__version__
'0.11.2'
# Scatter every image by the two columns of `reduced` (the output of the
# 2-cluster model's transform), coloured by its label.
plt.figure(figsize=(10, 10))
sns.scatterplot(
    x=reduced[:, 0], y=reduced[:, 1],
    hue=labels, palette='tab10', s=100,
);
selection = np.isin(labels, [1, 2])
selection
array([False, False, False, ..., True, True, True])
reduced[selection, 0].shape
(1051,)
reduced[selection, 1].shape
(1051,)
# Boolean mask picking only the images labelled 0 or 1.
selection = np.isin(labels, [0, 1])
# Same scatter as before, restricted to the masked rows.
plt.figure(figsize=(10, 10))
sns.scatterplot(
    x=reduced[selection, 0], y=reduced[selection, 1],
    hue=labels[selection], palette='tab10', s=100,
);
import pandas as pd
# Sample data for the pandas examples: four parallel lists with one entry per
# planet, all in the same order.
# NOTE(review): "Jupyter" looks like a typo for "Jupiter"; it is left as is
# because the recorded outputs use this spelling.
planets = ["Mercury", "Venus", "Earth", "Mars", "Jupyter", "Saturn", "Uranus", "Neptune"]
# Diameters, presumably in kilometres (TODO confirm).
diameters = [4879, 12104, 12756, 6792, 142984, 120536, 51118, 49528]
# Temperatures in degrees Celsius (the temp_F column is later derived from
# them with *1.8 + 32).
temperatures = [167, 464, 15, -65, -110, -140, -195, -200]
# Gravity, presumably surface gravity in m/s^2 (TODO confirm).
gravity = [3.7, 8.9, 9.8, 3.7, 23.1, 9.0, 8.7, 11.0]
s = pd.Series(diameters)
s
0 4879 1 12104 2 12756 3 6792 4 142984 5 120536 6 51118 7 49528 dtype: int64
s = pd.Series(diameters, index=planets)
s
Mercury 4879 Venus 12104 Earth 12756 Mars 6792 Jupyter 142984 Saturn 120536 Uranus 51118 Neptune 49528 dtype: int64
s['Mars']
6792
s[['Mars', 'Earth']]
Mars 6792 Earth 12756 dtype: int64
s['Earth':'Saturn']
Earth 12756 Mars 6792 Jupyter 142984 Saturn 120536 dtype: int64
s['Pluto'] = 2370
s
Mercury 4879 Venus 12104 Earth 12756 Mars 6792 Jupyter 142984 Saturn 120536 Uranus 51118 Neptune 49528 Pluto 2370 dtype: int64
s.mean()
44785.22222222222
s.max()
142984
s.min()
2370
s.argmax()
4
s.idxmax()
'Jupyter'
s/1.61
Mercury 3030.434783 Venus 7518.012422 Earth 7922.981366 Mars 4218.633540 Jupyter 88809.937888 Saturn 74867.080745 Uranus 31750.310559 Neptune 30762.732919 Pluto 1472.049689 dtype: float64
def size(x, threshold=10000):
    """Classify a value as "small" or "large".

    Args:
        x: Value to classify (anything comparable with ``threshold``).
        threshold: Cut-off; values strictly below it count as "small".
            Defaults to 10000, the limit that used to be hard-coded.

    Returns:
        "small" if ``x < threshold``, otherwise "large".
    """
    return "small" if x < threshold else "large"
s.apply(size)
0 small 1 large 2 large 3 small 4 large 5 large 6 large 7 large dtype: object
# Combine the three per-planet lists into one table, one column per list,
# in the order diameter, temperature, gravity.
df = pd.DataFrame(
    dict(diameter=diameters, temperature=temperatures, gravity=gravity)
)
df
diameter | temperature | gravity | |
---|---|---|---|
0 | 4879 | 167 | 3.7 |
1 | 12104 | 464 | 8.9 |
2 | 12756 | 15 | 9.8 |
3 | 6792 | -65 | 3.7 |
4 | 142984 | -110 | 23.1 |
5 | 120536 | -140 | 9.0 |
6 | 51118 | -195 | 8.7 |
7 | 49528 | -200 | 11.0 |
df.index
RangeIndex(start=0, stop=8, step=1)
df.index = planets
df
diameter | temperature | gravity | |
---|---|---|---|
Mercury | 4879 | 167 | 3.7 |
Venus | 12104 | 464 | 8.9 |
Earth | 12756 | 15 | 9.8 |
Mars | 6792 | -65 | 3.7 |
Jupyter | 142984 | -110 | 23.1 |
Saturn | 120536 | -140 | 9.0 |
Uranus | 51118 | -195 | 8.7 |
Neptune | 49528 | -200 | 11.0 |
df.index
Index(['Mercury', 'Venus', 'Earth', 'Mars', 'Jupyter', 'Saturn', 'Uranus', 'Neptune'], dtype='object')
df.columns
Index(['diameter', 'temperature', 'gravity'], dtype='object')
df.head(3)
diameter | temperature | gravity | |
---|---|---|---|
Mercury | 4879 | 167 | 3.7 |
Venus | 12104 | 464 | 8.9 |
Earth | 12756 | 15 | 9.8 |
df.tail(3)
diameter | temperature | gravity | |
---|---|---|---|
Saturn | 120536 | -140 | 9.0 |
Uranus | 51118 | -195 | 8.7 |
Neptune | 49528 | -200 | 11.0 |
df.sample(3)
diameter | temperature | gravity | |
---|---|---|---|
Neptune | 49528 | -200 | 11.0 |
Venus | 12104 | 464 | 8.9 |
Mercury | 4879 | 167 | 3.7 |
df['gravity']
Mercury 3.7 Venus 8.9 Earth 9.8 Mars 3.7 Jupyter 23.1 Saturn 9.0 Uranus 8.7 Neptune 11.0 Name: gravity, dtype: float64
df[['gravity', 'diameter']]
gravity | diameter | |
---|---|---|
Mercury | 3.7 | 4879 |
Venus | 8.9 | 12104 |
Earth | 9.8 | 12756 |
Mars | 3.7 | 6792 |
Jupyter | 23.1 | 142984 |
Saturn | 9.0 | 120536 |
Uranus | 8.7 | 51118 |
Neptune | 11.0 | 49528 |
df.loc['Earth', 'gravity']
9.8
df.loc[['Earth', 'Mars'], 'gravity']
Earth 9.8 Mars 3.7 Name: gravity, dtype: float64
df.loc[['Earth', 'Mars'], ['gravity', 'temperature']]
gravity | temperature | |
---|---|---|
Earth | 9.8 | 15 |
Mars | 3.7 | -65 |
df.iloc[0, 1]
167
df.iloc[0]
diameter 4879.0 temperature 167.0 gravity 3.7 Name: Mercury, dtype: float64
df[2:5]
diameter | temperature | gravity | |
---|---|---|---|
Earth | 12756 | 15 | 9.8 |
Mars | 6792 | -65 | 3.7 |
Jupyter | 142984 | -110 | 23.1 |
df.iloc[2:5, [0, 1]]
diameter | temperature | |
---|---|---|
Earth | 12756 | 15 |
Mars | 6792 | -65 |
Jupyter | 142984 | -110 |
df
diameter | temperature | gravity | |
---|---|---|---|
Mercury | 4879 | 167 | 3.7 |
Venus | 12104 | 464 | 8.9 |
Earth | 12756 | 15 | 9.8 |
Mars | 6792 | -65 | 3.7 |
Jupyter | 142984 | -110 | 23.1 |
Saturn | 120536 | -140 | 9.0 |
Uranus | 51118 | -195 | 8.7 |
Neptune | 49528 | -200 | 11.0 |
df['diameter'] > 10000
Mercury False Venus True Earth True Mars False Jupyter True Saturn True Uranus True Neptune True Name: diameter, dtype: bool
df[df['diameter'] > 10000]
diameter | temperature | gravity | |
---|---|---|---|
Venus | 12104 | 464 | 8.9 |
Earth | 12756 | 15 | 9.8 |
Jupyter | 142984 | -110 | 23.1 |
Saturn | 120536 | -140 | 9.0 |
Uranus | 51118 | -195 | 8.7 |
Neptune | 49528 | -200 | 11.0 |
df[(df['temperature'] > 0) & (df['gravity'] > 5)]
diameter | temperature | gravity | |
---|---|---|---|
Venus | 12104 | 464 | 8.9 |
Earth | 12756 | 15 | 9.8 |
df[(df['temperature'] > 0) | (df['temperature'] < -100)]
diameter | temperature | gravity | |
---|---|---|---|
Mercury | 4879 | 167 | 3.7 |
Venus | 12104 | 464 | 8.9 |
Earth | 12756 | 15 | 9.8 |
Jupyter | 142984 | -110 | 23.1 |
Saturn | 120536 | -140 | 9.0 |
Uranus | 51118 | -195 | 8.7 |
Neptune | 49528 | -200 | 11.0 |
df[~(df['temperature'] > 0)]
diameter | temperature | gravity | |
---|---|---|---|
Mars | 6792 | -65 | 3.7 |
Jupyter | 142984 | -110 | 23.1 |
Saturn | 120536 | -140 | 9.0 |
Uranus | 51118 | -195 | 8.7 |
Neptune | 49528 | -200 | 11.0 |
df
diameter | temperature | gravity | |
---|---|---|---|
Mercury | 4879 | 167 | 3.7 |
Venus | 12104 | 464 | 8.9 |
Earth | 12756 | 15 | 9.8 |
Mars | 6792 | -65 | 3.7 |
Jupyter | 142984 | -110 | 23.1 |
Saturn | 120536 | -140 | 9.0 |
Uranus | 51118 | -195 | 8.7 |
Neptune | 49528 | -200 | 11.0 |
df.sort_values(by='gravity', ascending=False)
diameter | temperature | gravity | |
---|---|---|---|
Jupyter | 142984 | -110 | 23.1 |
Neptune | 49528 | -200 | 11.0 |
Earth | 12756 | 15 | 9.8 |
Saturn | 120536 | -140 | 9.0 |
Venus | 12104 | 464 | 8.9 |
Uranus | 51118 | -195 | 8.7 |
Mercury | 4879 | 167 | 3.7 |
Mars | 6792 | -65 | 3.7 |
df.sort_index()
diameter | temperature | gravity | |
---|---|---|---|
Earth | 12756 | 15 | 9.8 |
Jupyter | 142984 | -110 | 23.1 |
Mars | 6792 | -65 | 3.7 |
Mercury | 4879 | 167 | 3.7 |
Neptune | 49528 | -200 | 11.0 |
Saturn | 120536 | -140 | 9.0 |
Uranus | 51118 | -195 | 8.7 |
Venus | 12104 | 464 | 8.9 |
df['temp_F'] = df['temperature']*1.8 + 32
df
diameter | temperature | gravity | temp_F | |
---|---|---|---|---|
Mercury | 4879 | 167 | 3.7 | 332.6 |
Venus | 12104 | 464 | 8.9 | 867.2 |
Earth | 12756 | 15 | 9.8 | 59.0 |
Mars | 6792 | -65 | 3.7 | -85.0 |
Jupyter | 142984 | -110 | 23.1 | -166.0 |
Saturn | 120536 | -140 | 9.0 | -220.0 |
Uranus | 51118 | -195 | 8.7 | -319.0 |
Neptune | 49528 | -200 | 11.0 | -328.0 |
df.columns
Index(['diameter', 'temperature', 'gravity', 'temp_F'], dtype='object')
df1 = df[['diameter', 'temperature', 'temp_F', 'gravity']]
df1
diameter | temperature | temp_F | gravity | |
---|---|---|---|---|
Mercury | 4879 | 167 | 332.6 | 3.7 |
Venus | 12104 | 464 | 867.2 | 8.9 |
Earth | 12756 | 15 | 59.0 | 9.8 |
Mars | 6792 | -65 | -85.0 | 3.7 |
Jupyter | 142984 | -110 | -166.0 | 23.1 |
Saturn | 120536 | -140 | -220.0 | 9.0 |
Uranus | 51118 | -195 | -319.0 | 8.7 |
Neptune | 49528 | -200 | -328.0 | 11.0 |
df
diameter | temperature | gravity | temp_F | |
---|---|---|---|---|
Mercury | 4879 | 167 | 3.7 | 332.6 |
Venus | 12104 | 464 | 8.9 | 867.2 |
Earth | 12756 | 15 | 9.8 | 59.0 |
Mars | 6792 | -65 | 3.7 | -85.0 |
Jupyter | 142984 | -110 | 23.1 | -166.0 |
Saturn | 120536 | -140 | 9.0 | -220.0 |
Uranus | 51118 | -195 | 8.7 | -319.0 |
Neptune | 49528 | -200 | 11.0 | -328.0 |
df.mean()
diameter 50087.1250 temperature -8.0000 gravity 9.7375 temp_F 17.6000 dtype: float64
df.min()
diameter 4879.0 temperature -200.0 gravity 3.7 temp_F -328.0 dtype: float64
df['gravity'].idxmin()
'Mercury'
df.loc[df['gravity'].idxmin()]
diameter 4879.0 temperature 167.0 gravity 3.7 temp_F 332.6 Name: Mercury, dtype: float64
df['gravity'].argmin()
0
df.iloc[df['gravity'].argmin()]
diameter 4879.0 temperature 167.0 gravity 3.7 temp_F 332.6 Name: Mercury, dtype: float64
def habitable(p):
    """Label one planet record as habitable ("Yes") or not ("No").

    Args:
        p: Record supporting ``p['temperature']`` and ``p['gravity']``
            lookups, e.g. a DataFrame row passed by ``df.apply(..., axis=1)``.

    Returns:
        "Yes" when the temperature lies strictly between -100 and 50 and the
        gravity is below 12; "No" otherwise.
    """
    # Gravity is only looked up when the temperature check has passed.
    if (-100 < p['temperature'] < 50) and p['gravity'] < 12:
        return "Yes"
    return "No"
df.apply(habitable, axis=1)
Mercury No Venus No Earth Yes Mars Yes Jupyter No Saturn No Uranus No Neptune No dtype: object
df['habitable'] = df.apply(habitable, axis=1)
df
diameter | temperature | gravity | temp_F | habitable | |
---|---|---|---|---|---|
Mercury | 4879 | 167 | 3.7 | 332.6 | No |
Venus | 12104 | 464 | 8.9 | 867.2 | No |
Earth | 12756 | 15 | 9.8 | 59.0 | Yes |
Mars | 6792 | -65 | 3.7 | -85.0 | Yes |
Jupyter | 142984 | -110 | 23.1 | -166.0 | No |
Saturn | 120536 | -140 | 9.0 | -220.0 | No |
Uranus | 51118 | -195 | 8.7 | -319.0 | No |
Neptune | 49528 | -200 | 11.0 | -328.0 | No |
import pandas as pd
import seaborn as sns
df = sns.load_dataset("titanic")
df
survived | pclass | sex | age | sibsp | parch | fare | embarked | class | who | adult_male | deck | embark_town | alive | alone | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 3 | male | 22.0 | 1 | 0 | 7.2500 | S | Third | man | True | NaN | Southampton | no | False |
1 | 1 | 1 | female | 38.0 | 1 | 0 | 71.2833 | C | First | woman | False | C | Cherbourg | yes | False |
2 | 1 | 3 | female | 26.0 | 0 | 0 | 7.9250 | S | Third | woman | False | NaN | Southampton | yes | True |
3 | 1 | 1 | female | 35.0 | 1 | 0 | 53.1000 | S | First | woman | False | C | Southampton | yes | False |
4 | 0 | 3 | male | 35.0 | 0 | 0 | 8.0500 | S | Third | man | True | NaN | Southampton | no | True |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
886 | 0 | 2 | male | 27.0 | 0 | 0 | 13.0000 | S | Second | man | True | NaN | Southampton | no | True |
887 | 1 | 1 | female | 19.0 | 0 | 0 | 30.0000 | S | First | woman | False | B | Southampton | yes | True |
888 | 0 | 3 | female | NaN | 1 | 2 | 23.4500 | S | Third | woman | False | NaN | Southampton | no | False |
889 | 1 | 1 | male | 26.0 | 0 | 0 | 30.0000 | C | First | man | True | C | Cherbourg | yes | True |
890 | 0 | 3 | male | 32.0 | 0 | 0 | 7.7500 | Q | Third | man | True | NaN | Queenstown | no | True |
891 rows × 15 columns
df['deck']
0 NaN 1 C 2 NaN 3 C 4 NaN ... 886 NaN 887 B 888 NaN 889 C 890 NaN Name: deck, Length: 891, dtype: category Categories (7, object): ['A', 'B', 'C', 'D', 'E', 'F', 'G']
len(df['deck'])
891
df['deck'].count()
203