From 9f96be84fc0c16b4327288263dff814d9717d4c0 Mon Sep 17 00:00:00 2001 From: jtsauer <jtsauer@zedat.fu-berlin.de> Date: Thu, 13 Jul 2017 23:33:39 +0200 Subject: [PATCH] basic implementation of svm classifier with toy example --- .../SVM_Predictor-checkpoint.ipynb | 124 ++++++++++++++++++ SVM_Predictor.ipynb | 124 ++++++++++++++++++ 2 files changed, 248 insertions(+) diff --git a/.ipynb_checkpoints/SVM_Predictor-checkpoint.ipynb b/.ipynb_checkpoints/SVM_Predictor-checkpoint.ipynb index cbe8341..53e8b35 100644 --- a/.ipynb_checkpoints/SVM_Predictor-checkpoint.ipynb +++ b/.ipynb_checkpoints/SVM_Predictor-checkpoint.ipynb @@ -536,6 +536,130 @@ " tol=0.001, verbose=False)" ] }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from sklearn.metrics.pairwise import linear_kernel, polynomial_kernel, rbf_kernel # there are some more..\n", + "import numpy as np\n", + "import cvxopt\n", + "\n", + "class SVM_Predictor(object):\n", + " \n", + " def __init__(self, kernel):\n", + " self.kernel = kernel\n", + " \n", + " def train(self, X, Y):\n", + " n_samples, dim = X.shape\n", + " K = self.kernel(X)\n", + " P = cvxopt.matrix(np.outer(y, y) * K)\n", + " q = cvxopt.matrix(-1 * np.ones(n_samples))\n", + " G = cvxopt.matrix(-np.eye(n_samples))\n", + " h = cvxopt.matrix(np.zeros(n_samples))\n", + " A = cvxopt.matrix(y, (1, n_samples))\n", + " b = cvxopt.matrix(0.0)\n", + " #solvers.options['show_progress'] = False\n", + " sol = cvxopt.solvers.qp(P, q, G, h, A, b)\n", + " self.alphas = np.array(sol['x'])" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " pcost dcost gap pres dres\n", + " 0: -1.1154e+01 -2.0912e+01 3e+02 2e+01 2e+00\n", + " 1: -1.7080e+01 -8.7902e+00 6e+01 3e+00 4e-01\n", + " 2: -1.1243e+01 -5.9376e+00 4e+01 2e+00 2e-01\n", + " 3: -3.7199e+00 -4.0980e+00 2e+00 8e-02 8e-03\n", + " 4: -3.6777e+00 -3.7750e+00 2e-01 5e-03 6e-04\n", + " 5: -3.7366e+00 -3.7654e+00 3e-02 2e-04 3e-05\n", + " 6: -3.7626e+00 -3.7629e+00 4e-04 3e-06 3e-07\n", + " 7: -3.7629e+00 -3.7629e+00 4e-06 3e-08 3e-09\n", + "Optimal solution found.\n", + "x (100, 2) y (100,)\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XtwXNWdJ/Dvz4pkWzaZVGTVxsGWZGOBbRyWCSomNQHj\nJ2VpjE0yOzVkvLuQyZSwjdcZMGSzS1WqQpVqd5MYYnvWgJjJlIlVSag8wBA3YOMXVMJDsIbB74ce\neEkVstkEC1kSWL/943a3Wq2+3fd239v33nO/n6ou1K2rq9MS/t6j3zn3HFFVEBGRWSYE3QAiIvIe\nw52IyEAMdyIiAzHciYgMxHAnIjIQw52IyEAMdyIiAzHciYgMxHAnIjLQZ4L6xtOmTdOGhoagvj0R\nUSS9+eab51W1ttBxgYV7Q0MDOjs7g/r2RESRJCI9To5jWYaIyEAMdyIiAzHciYgMxHAnIjIQw52I\nyEAMdyIiAzHciUzX0QE0NAATJlj/7egIukVUBoHNcyeiMujoAFpbgYEB63lPj/UcANasCa5d5Dv2\n3IlM9uCDo8GeMjBgvU5GY7gTmay3193rZAyGO5HJ6urcvU7GYLgTmaytDaiuHvtadbX1OhmtYLiL\nyCQReV1E3haRIyLy/RzH3CUifSJyOPn4B3+aS0SurFkDtLcD9fWAiPXf9nYOpsaAk9kyQwCWqGq/\niFQCeEVEEqr6atZxv1DVDd43kYhKsmYNwzyGCoa7qiqA/uTTyuRD/WwUERGVxlHNXUQqROQwgA8A\n7FHV13Ic9tci8o6I/FJEZtqcp1VEOkWks6+vr4RmExFRPo7CXVUvq+r1AGYAuFFEFmQd8iyABlW9\nDsBeADtsztOuqk2q2lRbW3AjESIiKpKr2TKq+kcABwCsyHr9gqoOJZ8+AeAGT1pHRKNKXUaAyxDE\nipPZMrUi8rnkx5MBLANwPOuY6RlPVwE45mUjiWIvtYxATw+gOrqMgNOAdvr1vAAYQ6zx0jwHiFwH\nq8xSAeti8JSqPiQiDwHoVNVdIvI/YIX6pwA+BLBOVY/bnhRAU1OTcg9VIocaGqxAzlZfD3R3e/P1\n2evQANaceE6dDBUReVNVmwoeVyjc/cJwJ3JhwgSrx51NBBgZ8ebrC10AOjqsNWl6e607XNvaGPoB\ncBruvEOVKApKXUbAydfnW4em1LIQlR3DnSjMUjXwnh6rl53JzTICTpYhyHcB4OqSkcNwJwqrzN4y\nYPWYUwHvdhkBJ8sQ5LsAcHXJyGHNnSisSh1ELcb69VboX74MVFRYF5ft24NpC+XEmjtR1OUK03yv\n5+NkimNHB7BjhxXsgPXfHTus17m6ZOQw3InCqqLC3et2Ae50MDRfXZ2rS0YOyzJEYZU9gJop+99t\nrjnqVVXAFVcAFy7kPkd2SaXU6ZZUFizLEEVdfb3z13P1uoeH7YMdGD8Yyl2bjMJwJ/JLqbfyu6lz\nFzNrJTu0WVc3CsOdyA9e3PTjps7ttnedK7RZVzcKa+5EfvB66mChW/9z1dzt1Ndz6YAIc1pzd7LN\nHhG55eVNP9nBnforABgN6NR/UxeAz38e+Ogj4JNPRs/DRcBihWUZIj8UGpx0U493euv/mjXWXwUj\nI8D588C//itLLDHGsgyRH/Itnwu4W1rXzZRIMh6nQhIFKd/gpNtFuOxuWgLy9/gz/zqYNs16cBOO\n2GDPnajc3N4slK/nbjdAW2iAVcRqAwdXI8eznruITBKR10XkbRE5IiLfz3HMRBH5hYicFpHXRKSh\nuGYTxYCbm4U6OvKHu90Aba6/DjKlLi5er8vObfpCw0lZZgjAElX99wCuB7BCRL6Sdcy3APw/VZ0D\n4BEA/8vbZhIZxM3NQg8+mL+ubnehcDMrx6t12bmhR6gUDHe19CefViYf2f+3rYa1zyoA/BLAUpF8\n3Q2iGHNzs1C+kM5396jbm5q8WJedG3qEiqMBVRGpEJHDAD4AsEdVX8s65EoA7wGAqn4K4E8Aarxs\nKJFRMqctdnfb17ztQrqiIv/Uxlx/HeTjxfox3NAjVByFu6peVtXrAcwAcKOILMg6JFcvfdzfkiLS\nKiKdItLZ19fnvrVEcdPWBlRWjn2tstJaZz3fIGj2Xwc1NdYDGF/Dr6wE+vtLr5PbXSAmTGBpJgCu\npkKq6h8BHACwIutT5wDMBAAR+QyAPwPwYY6vb1fVJlVtqq2tLarBRLGTHcZOK57ZNzWdP2/Vwn/6\n07GhL2KtHllqndzur4XLl8eek4Ou5aGqeR8AagF8LvnxZAAvA1iZdcw9AB5LfnwHgKcKnfeGG25Q\nIspj507VigpVK3bHPurrvfke9fXenr9Qm3fuVK2uHvt6dbX1OjkCoFML5KuqFp7nLiLXwRosrYDV\n039KVR8SkYeS32SXiEwC8FMAfw6rx36Hqp7Nd17OcyfKw8k8dS820PBjg45856yr416sJfJs4TBV\nfQdWaGe//r2MjwcB/I3bRhKRjULz1D//eW++j13YljLAmu+cHHQtGy4/QBRGTsLOi9q11xt0dHRY\ng7PZUufkbk9lw3AnCqNCYXfhAvDNb5Z+w5CXG3SkSknZW/vV1Iyek7s9lQ3DneIrzLM2nMxTz1yr\nHSj+hiGnc+4LsSslTZ06dt157vZUFgx3iicnt8oHGf6pEKxxeS9gZjnHr/bbnddpPd2riwnl52RK\njR8PToWkQBWaAhimKXs7d+ZuaxDtz3der6dVUk5wOBWS4U7RkwoSkdG5026J5A4iEevzYQsqu/bY\nhbdf7c933jBdEA3mNNxZlqFo8WrlwUKzNsI2Za9QDV4EuPPO0RKHX+3Pd17W00OF4U7R4tXKg4Vm\nbYRtyl5mcOaiCuzePfrcr/YXOi/r6aHBcKdo8apHWqiX2dYGVFWN/7r+/uBm1aSC025tmcyfQa72\nV1WVPuWQUxkjg+FO0eJlj7RQL1Nz3EJ/4ULwG1A4/Rlktz/X+3GLpZfIYLhnUS/+AZB/ytVzfPDB\n8fPIU4LegKKlpfDrudr/ySfetJull0hguGdpbm7G7bffjscffxy9XO8ifMrVcyz0uw/y/43M2rrd\n63bt6+kJ181a5BuGewZVRWNjIw4fPoy1a9eivr4e1157LR544AHs27cPw8PDQTeRgPL0HAuVeYJc\nC8XJuEO+9v393wPTpoXzzlzyDMM9g4hg27Zt6OrqwtGjR7F582Z88YtfxNatW7F06VLU1NSwVx8X\n+aYeBj2A6KTmnq/9w8PebM5B4eZkMrwfjyjdxHTx4kV95plndO3atVpfX6+wthDU+fPn66ZNm3Tv\n3r06NDQUdDPJa5l3XaY2oCj2pimv2+XkZqFi7myl0INXm3X4Jaqbdagqjh8/jkQigUQigUOHDmF4\neBhTpkzBsmXL0NzcjObmZtRxCVPyU0eHNTja22v12NvacpenGhpyr62ezavNP8h3TjfrcLIT00wA\nTwL4AoARAO2quiXrmEUAngHQlXzp16r6UL7zRjXcs/X392P//v3YvXs3EokEepL/kObPn58O+ptu\nugkTJ04MuKUUS4V2dErhTkiR4TTcC3btAUwH8OXkx1cAOAlgftYxiwA85+RPhdQjSmUZp0ZGRvTo\n0aO6efNmXbZsmVZVVSkAnTJliq5atUofffRR7e7uDrqZVKydO1VrakZLGTU1wZdonMhci6emRrWy\nkuu/RBj8WjgMVg99edZrDPccLl68qLt27dJ169ZpQ0NDulY/b948ve+++3Tv3r06ODgYdDPJiZ07\nVauqxteqKyujF4xeLLxGgXEa7q5q7iLSAOAQgAWq+lHG64sA/ArAOQDvA7hfVY/kO5cpZRmnVBUn\nTpxAIpHA7t27x9Tqly5dmi7h1NutHULByle7ZkmDysizmnvGCacCOAigTVV/nfW5zwIYUdV+EWkB\nsEVVG3OcoxVAKwDU1dXd0ONkoMdQqVp9amC2OxkO8+bNSwf9zTffzFp9WEyYYH/7PgcjqYw8DXcR\nqQTwHIAXVPVhB8d3A2hS1fN2x8St555PZq8+kUjg4MGD7NWHTVA9d6ezYig2vBxQFVizZX6c55gv\nYPRCcSOA3tRzu0ccau7F6u/v12effVbXr1+fs1a/Z88e1uq94rT+HETNPcybX7BuHxh4NaAK4KZk\nuLwD4HDy0QJgLYC1yWM2ADgC4G0ArwL4y0LnZbg7MzIyoseOHdOHH35Yly9fPmYGzm233abbt2/X\nrq6uoJsZXvlCyG14lnu2TNh2g0oJ80UnBpyGO29iipiPP/54TK2+q8u6tWDu3LloaWlhrT5Trjne\n1dWjC43ZlVrCMkBqV+cPusYf9p+b4TwfUPUaw710qoqTJ0+mg/7AgQPpWv2SJUvStfqGhoagmxqM\nQiEU1vBMCWuIhv3nZjiGewzl69Wngn7hwoXx6dUXCqGwhmdKob88ghL2n5vhPBtQ9evBmru/RkZG\n9Pjx4/rII4/orbfeqhMnTlQAWl1drStXroxHrb5QzToKteMwDlxG4edmMPh1h6pXD4Z7efX39+tz\nzz2n99xzj86ePTs9A2fu3Ll677336osvvmjeDBwnIRTG8IwC/twC4zTcWZaJIVXFqVOn0oudHTx4\nEENDQ6iurh5Tq581a1bQTS0d54mTYVhzJ8cGBgbG1OrPnj0LALjmmmvQ3NyMlpaWeNXqiUKM4U5F\nSfXqM2fgGNurJ4oghjt5YmBgAAcOHEiXcLJ79akZOJMmTQq4pUTxwHAnz+Xr1S9evDgd9rNnzw66\nqfHFMQbjMdzJd6lefSrsz5w5A4C9+sCUMi+eF4XIYLhT2WX26vfv389efbkVe3NRWG+WopwY7hQo\nu1791VdfnQ76W265hb16LxW7LADvOI0UhjuFSnatfnBwEJMnTx4zA4e9+hIVG9JcKyZSGO4UWgMD\nAzh48GB6Bg579R4ptrzCnnukMNwpMux69YsXL04vY8xevUPFDIyy5h4pDHeKpEuXLo2p1Z8+fRoA\ne/W+42yZyPAs3EVkJqxt9r4AYARAu6puyTpGAGyBtUPTAIC7VPWtfOdluHvD9H+Tp0+fRiKRwO7d\nu8f16lNhf9VVVwXdTKKy8TLcpwOYrqpvicgVAN4EcLuqHs04pgXAf4EV7n8BYIuq/kW+8zLcSxe3\nv6btevWNjY3pNXDYqyfT+VaWEZFnAPyTqu7JeO1xAAdU9WfJ5ycALFLVP9idh+FeuriPg6V69al5\n9ezVUxz4Eu4i0gDgEIAFqvpRxuvPAfifqvpK8vlLAP6rqtqmN8O9dJzBNurSpUs4ePBguoST3atP\n1eonT54ccEuJSuM03Ce4OOFUAL8C8I+ZwZ76dI4vGRc7ItIqIp0i0tnX1+f0Wxeto8Pq3U6YYP23\no8P3b1lWdXXOXzf9ZzF58mSsWLECW7ZswalTp3Dq1Cls3boVc+bMQXt7O5qbm1FTU4OWlhZs27Yt\nHf5ExnKyoweASgAvALjP5vOPA/hGxvMTsOr0ge3EFIedwJy+xzj8LPIZGBjQRCKhGzdu1MbGxvQu\nVI2Njbpx40ZNJBI6MDAQdDOJHIFX2+zB6pU/CeDHeY75KwCJ5LFfAfB6ofP6He6Fts80hZPdzuLy\ns3Dq9OnTum3bNm1padFJkyYpAJ00aZI2Nzfr1q1b9dSpU0E3kciW03B3MlvmJgAvA/g3WFMhAeC/\nA6hL9vwfS06F/CcAK2BNhfym5qm3A/7X3FmPHsWfhb3MWn0ikcCpU6cAAHPmzEnX6hctWsRaPYVG\n7G9iivtMkkx2P4uaGuD8+bI3J9TOnDkzZgbOpUuXMGnSpDEzcObMmRN0MynGPB9QjZq2NmvOd6bq\naut1UzgdJG1rAyorx79+4QIwbZp5g6uluOqqq7Bhwwb89re/xYULF/D888/j7rvvxpkzZ7Bx40Y0\nNjaisbERGzduRCKRwKVLl4JuMlFuTmo3fjz8rrmrOqtHR5XbQdKamtx197gNrpYis1Y/efLkdK1+\nxYoVumXLFj158mTQTaQYgFc1d79wnntp3Jad7Oruhb6OchscHBxTqz958iQAq+efuluWtXryQ+xr\n7qZzO0hqdzEo9HXkzNmzZ9NBv2/fvnStftGiRelafWNjY9DNJAMw3A3ntueeax0aJ19H7hXq1adm\n4FRnDwoROcBwN1wxi4Z1dADf/rY1kJrJ5MXGwsCuV3/LLbekSzjs1ZNTDPcYKHa5X9OXCQ6zwcFB\nHDp0KB32J06cAMBePTnHcCeKgFy9+okTJ46r1Vv3CRIx3Ikix65XP3v27HTQL168mL36mGO4lwHL\nG+Snrq6uMb36gYEBTJw4Ebfcckt6b1n26uOH4e6zuO2CFGZxuMgODg7i5ZdfTq9Xz159fDHcfca1\na8IhrhfZfL36VNhfffXV7NUbiOHusyivtGhST5cX2bG9+kQigePHjwMAZs2alZ5qyV69OZyGu9Fr\ny9jxYs2ZqK6RbtrGHSK5fw8iQbcsOGfPntXt27frypUrtbq6WgHoxIkT9dZbb9VHHnlEjx8/riMj\nI0E3k4oErzbr8OsRVLh7FW5RDcmoXpTsmPZ+vHbp0iV98cUX9d5779W5c+emd6GaNWuWrl+/Xp99\n9lnt7+8PupnkAsPdhpdhEMVVJ6PS03X6s43qRTYoXV1dun37dr3tttvG9OqXL1+uDz/8MHv1EeBZ\nuAP4CYAPALxr8/lFAP4E4HDy8T0n3ziocI9KuPklCj1dt4EdxYtsGAwODuqePXv0vvvuY68+QpyG\nu5Nt9hYC6AfwpKouyPH5RQDuV9WVBQv8GYIaUI37AFwUZpfE/XcUlO7u7vSg7EsvvZSegbNw4cL0\nDJxrrrmGM3AC5umAKoAG5O+5P+fkPJmPqNfcoyzsPd24/3UVBpm9+nnz5qV79Q0NDbpu3TrdtWsX\ne/UBgZc1dwfhfgHA2wASAK51cs6oz5aJ4veOiiiUjuKmq6tLH330UV21apVOmTJFAWhVVVW6Vn/s\n2DHW6suknOH+WQBTkx+3ADiV5zytADoBdNbV1ZXj5xAq/KvBGf6cwo29+mA5DXdHNzGJSEOy9DKu\n5p7j2G4ATap6Pt9xUb+JqRisJTtn0o1Wpuvp6RlTq//4449RVVWFhQsXptfAYa3eO+WsuX8Bo3e6\n3gigN/U83yPIskxQ3NSSTSjfmPAeyJ3BwUHdu3evbtq0SefPn89evQ/g4VTInwH4A4BPAJwD8C0A\nawGsTX5+A4AjsGrurwL4SyffOI7h7rSWbEJZwoT3QKXr7u7OWatftmyZbt68WY8ePcpavUtOw51r\ny5SR02mIJpRvTHgP5K2hoSG88sor6RLO0aNHAQD19fXpNXCWLFmCKVOmBNzScOPCYSHlpJYc5UXJ\nUkx4D+Svnp4ePP/880gkEti7d++YWn1qXv3cuXNZq8/CcI+wcvZ6/Rq4ZM+d3BgeHk736nfv3j2u\nV9/c3IwlS5Zg6tSpAbc0eFwVMsJKrVeHYV0W1typFD09PfrYY4/p6tWrderUqela/dKlS/VHP/pR\nrGv14MJh0VbsTBM3oer3zUKcLUNeGBoa0pdeeknvv/9+vfbaa9MzcOrr63Xt2rX6zDPP6MWLF4Nu\nZtk4DXeWZSLATenETTmEdXGKot7e3jG1+v7+flRVVeHmm29Ol3DmzZtnbK3e+LJMXHqFbssbbubS\ne9Vzj8vvgsInX6/+7rvv1qefftq4Xj1MLsvEqZ7rNoDdHO/Fz9H03wUvXNHS09Ojjz/+uN5+++05\na/VHjhyJfK3e6HCP08JSbldILPda6OX8XZQ7aE2/cJluaGhI9+3bpw888IAuWLAg3auvq6uLdK/e\n6HCP05KwxYRnOUOwXL+LIII2Tp2IOOjt7R3Xq6+srNQlS5boD3/4Q3333Xcj0as3Otzj9I/OSagF\nWToo1+8iiN95nDoRcRPlXr3R4R63P5fzhXfQP4tiv7/bC1IQQRunTkTc9fb2ant7eyR69UaHuyoH\nulLCEEBufxfFXBCCeJ9BXzgpGPl69a2trfr000/rRx99FFj7jA93skSxdFDsOEIQQctOBKV69V/7\n2tf0iiuuCLxXz3CPiSjOVS/2gsSgpaANDQ3p/v379Tvf+c6YXv3MmTO1tbVVf/Ob3/jeq2e4x0TU\n5qrv3KlaUeHNBYkoaHa9+sWLF+sPfvADX3r1DPcYKWUdGruevx9hm+siwlo2mSKzV/+lL33Jt149\nwz1msgN+3br8gZ8vaP2q29tdSCoqGOxknvfee0+feOIJ/frXvz6uV//UU08VfV6n4V5w4TAR+QmA\nlQA+0BwbZIu1Os8WAC0ABgDcpapvFVrThguHeSfXDk/Zsnd8sltgLJPXa69zoTKKq+HhYfzud79L\n70J15513YtOmTUWdy7OFwwAsBPBl2G+Q3QIgAUAAfAXAa06uKib03MMywJevtGJXZil0rB9lkjBM\n2yQKg8uXLxf9tXDYc5/gIPwPAfgwzyGrATyZ/L6vAviciEwveFWJuFRvuafHiqieHut5R0f529Lb\n6/64igr74+rrx+/r6oW2NusviEzV1dbrRHEyYULB6C39e3hwjisBvJfx/FzytXFEpFVEOkWks6+v\nz4NvHZwHHxxfBhkYsF4vt7o698ddvmx/XHe398EOWOdsb7cuHiL+XUSIyJtwz7Uifs5Cvqq2q2qT\nqjbV1tZ68K2DY9dbdtqL9lKuHnG27B5yfX3u4+xe98qaNdbFY2TEv4sIEXkT7ucAzMx4PgPA+x6c\nN9TsestOe9FeSvWIa2pyf37CBODOO8cGaa4LQlUV0N9vHd/QEEyJiYi84UW47wLwn8XyFQB/UtU/\neHDeUAtb/XjNGsBuY/iREWDHjtGwTm3bNzAwWnuvqbHGDi5cCH4MgYhKVzDcReRnAH4P4BoROSci\n3xKRtSKyNnnIbgBnAZwG8ASA9b61NkTCWD/OVxJKjQdkDgQDVu09dZH65JPcX5Opo8Pq1bN3TxRu\n3CDbIIXmrotYZaNC89uzvyY1Bz3XfPrs+fNE5C+n89z9n49DZVNoYLWuzv2Ab+YYQphmCBFRfgx3\ng+QbWE2NB9gN+NbUFB5DCNMMISLKj+EeQqXUtdesAc6fB3buzD0eYDcQvGVL4TGEMM0QIqICnNzG\n6sfDhOUH/ODl8rt2yyOUsookdyYiCha4KmS4OA1ULzff8COIS1lPJyxr8RBFGcM9RNwEbTG7FOUK\nzbAt0sVeP5E3nIY7p0KWgd0UxVxL6ro5FrCfnmi3/G9Qy+u6fV9ElBunQhbBrxt03MwycXvnq930\nRLtVH4Ma/ORMG6LyYrgn+bmEr5tZJm7vfLULx8w7T1NEgJYW5+12wukFkTNtiMrMSe3Gj0fYau5+\n1qj9rDfna/e6deNr+F7Wud28L9bcibwBDqi6U8xApht+zRTJF5p+D6q6PT9nyxCVzmm4c0A1KcoD\nfqlVHnt7rTJHW5tVxvF7z1LuiUpUfhxQdSlsS/i6YbcBht91btbRicKL4Z4UxiV8S+X3BSvKF0Qi\n0zHcM5i2BZzfFywTL4hEpmDNnYgoQjytuYvIChE5ISKnReS7OT5/l4j0icjh5OMfimk0ERF5w8k2\nexUA/jeAZgDzAXxDRObnOPQXqnp98vHPHreTsnC7OyLKx0nP/UYAp1X1rKoOA/g5gNX+Nivcgg5W\nP++mJSIzOAn3KwG8l/H8XPK1bH8tIu+IyC9FZGauE4lIq4h0ikhnX19fEc0NXhiCldvdEVEhTsJd\ncryWPQr7LIAGVb0OwF4AO3KdSFXbVbVJVZtqa2vdtTQkwhCsXISLiApxEu7nAGT2xGcAeD/zAFW9\noKpDyadPALjBm+aFTxiClTcPEVEhTsL9DQCNIjJLRKoA3AFgV+YBIjI94+kqAMe8a2K4hCFYo3Dz\nUNDjEkRxVzDcVfVTABsAvAArtJ9S1SMi8pCIrEoetlFEjojI2wA2ArjLrwYHLQzBGvabh8IwLkEU\nd7yJqQh2C3WRJcqLsBGFndObmBju5DmuFknkH64KSYEJw7gEUdwx3MlzYRiXIIo7hjt5LuwDvkRx\n8JmgG0BmWrOGYU4UJPbciYgMxHAnIjIQw52IyEAMdyIiAzHciYgMxHAnX3DhMKJgcSokeS61cFhq\n3fvUwmEAp0cSlQt77uS5MGxoQhR3DHfyXBg2NCGKO4Y7eSKzxj7B5v8qLhxGVD4Md4/EeQAxe3OO\ny5fHH8OFw4jKy1G4i8gKETkhIqdF5Ls5Pj9RRH6R/PxrItLgdUPDLO47D+WqsQNARQUXDiMKSsHN\nOkSkAsBJAMthbZb9BoBvqOrRjGPWA7hOVdeKyB0Avqaqf5vvvCZt1hH3nYe4OQdR+Xi5WceNAE6r\n6llVHQbwcwCrs45ZDWBH8uNfAlgqIuKmwVEW9wFEbs5BFD5Owv1KAO9lPD+XfC3nMckNtf8EoMaL\nBkZB3MONm3MQhY+TcM/VA8/+I9zJMRCRVhHpFJHOvr4+J+2LhLiHGzfnIAofJ+F+DsDMjOczALxv\nd4yIfAbAnwH4MPtEqtquqk2q2lRbW1tci0OI4Wa91+5uq8be3R2v904URk6WH3gDQKOIzALwfwHc\nAeDvso7ZBeBOAL8H8B8A7NNCI7WG4c5DRBQmBcNdVT8VkQ0AXgBQAeAnqnpERB4C0KmquwD8C4Cf\nishpWD32O/xsNBER5edo4TBV3Q1gd9Zr38v4eBDA33jbNCIiKhbvUCUiMhDDnYjIQAx3IiIDMdyJ\niAzEcDdEnFelJKLxuM2eAbitHRFlY8/dANzWjoiyMdwNEPdVKYloPIa7AeK+KiURjcdwN0DcV6Uk\novEY7gbgqpRElI2zZQzBVSmJKBN77kREBmK4ExEZiOFORGQghjsRkYEY7kREBmK4ExEZiOFORGQg\nUdVgvrFIH4CeQL55YdMAnA+6ET7i+4s2098fYP57LOX91atqbaGDAgv3MBORTlVtCrodfuH7izbT\n3x9g/nssx/tjWYaIyEAMdyIiAzHcc2sPugE+4/uLNtPfH2D+e/T9/bHmTkRkIPbciYgMxHDPICIr\nROSEiJwWke8G3R6vichPROQDEXk36Lb4QURmish+ETkmIkdE5NtBt8lLIjJJRF4XkbeT7+/7QbfJ\nDyJSISKJtUaJAAACZ0lEQVT/R0SeC7otXhORbhH5NxE5LCKdvn4vlmUsIlIB4CSA5QDOAXgDwDdU\n9WigDfOQiCwE0A/gSVVdEHR7vCYi0wFMV9W3ROQKAG8CuN2U36GICIApqtovIpUAXgHwbVV9NeCm\neUpE7gPQBOCzqroy6PZ4SUS6ATSpqu9z+NlzH3UjgNOqelZVhwH8HMDqgNvkKVU9BODDoNvhF1X9\ng6q+lfz4IoBjAK4MtlXeUUt/8mll8mFU70xEZgD4KwD/HHRboo7hPupKAO9lPD8Hg4IhbkSkAcCf\nA3gt2JZ4K1myOAzgAwB7VNWo9wfgxwC+A2Ak6Ib4RAG8KCJvikirn9+I4T5KcrxmVK8oLkRkKoBf\nAfhHVf0o6PZ4SVUvq+r1AGYAuFFEjCmvichKAB+o6ptBt8VHX1XVLwNoBnBPslTqC4b7qHMAZmY8\nnwHg/YDaQkVK1qJ/BaBDVX8ddHv8oqp/BHAAwIqAm+KlrwJYlaxL/xzAEhHZGWyTvKWq7yf/+wGA\n38AqB/uC4T7qDQCNIjJLRKoA3AFgV8BtIheSA47/AuCYqj4cdHu8JiK1IvK55MeTASwDcDzYVnlH\nVf+bqs5Q1QZY//72qep/DLhZnhGRKcmBfojIFAC3AvBt5hrDPUlVPwWwAcALsAbinlLVI8G2ylsi\n8jMAvwdwjYicE5FvBd0mj30VwH+C1eM7nHy0BN0oD00HsF9E3oHVGdmjqsZNFzTYvwPwioi8DeB1\nAL9V1ef9+macCklEZCD23ImIDMRwJyIyEMOdiMhADHciIgMx3ImIDMRwJyIyEMOdiMhADHciIgP9\nf7clxzX6+Kk6AAAAAElFTkSuQmCC\n", + "text/plain": [ + "<matplotlib.figure.Figure at 0x7fe285ddec50>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "dim = 2\n", + "colours = ['red', 'blue']\n", + "\n", + "# 2-D mean of ones\n", + "M1 = np.ones((dim,))\n", + "# 2-D mean of threes\n", + "M2 = 3 * np.ones((dim,))\n", + "# 2-D covariance of 0.3\n", + "C1 = np.diag(0.3 * np.ones((dim,)))\n", + "# 2-D covariance of 0.2\n", + "C2 = np.diag(0.2 * np.ones((dim,)))\n", + "\n", + "def generate_gaussian(m, c, n_samples):\n", + " return np.random.multivariate_normal(m, c, n_samples)\n", + "\n", + "def plot_data_with_labels(x, y, ax):\n", + " unique = np.unique(y)\n", + " for li in range(len(unique)):\n", + " x_sub = x[y == unique[li]]\n", + " ax.scatter(x_sub[:, 0], x_sub[:, 1], c = colours[li])\n", + "\n", + "def plot_separator(w, b, ax):\n", + " slope = -w[0] / w[1]\n", + " intercept = -b / w[1]\n", + " x = np.arange(0, 6)\n", + " ax.plot(x, x * slope + intercept, 'k-')\n", + " \n", + "n_samples = 50\n", + "\n", + "# generate 50 points from gaussian 1\n", + "x1 = generate_gaussian(M1, C1, n_samples)\n", + "# labels\n", + "y1 = np.ones((x1.shape[0],))\n", + "# generate 50 points from gaussian 2\n", + "x2 = generate_gaussian(M2, C2, n_samples)\n", + "y2 = -np.ones((x2.shape[0],))\n", + "# join\n", + "x = np.concatenate((x1, x2), axis = 0)\n", + "y = np.concatenate((y1, y2), axis = 0)\n", + "\n", + "lin_clf = SVM_Predictor(linear_kernel)\n", + "lin_clf.train(x, y)\n", + "\n", + "# get weights\n", + "w = np.sum(lin_clf.alphas * y[:, None] * x, axis = 0)\n", + "# get bias\n", + "cond = (lin_clf.alphas > 1e-4).reshape(-1)\n", + "b = y[cond] - np.dot(x[cond], w)\n", + "bias = b[0]\n", + "\n", + "print('x {} y {}'.format(x.shape, y.shape))\n", + "fig, ax = plt.subplots()\n", + "plot_separator(w, bias, ax)\n", + "plot_data_with_labels(x, y, ax)\n", + "plt.show()" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/SVM_Predictor.ipynb b/SVM_Predictor.ipynb index cbe8341..53e8b35 100644 --- a/SVM_Predictor.ipynb +++ b/SVM_Predictor.ipynb @@ -536,6 +536,130 @@ " tol=0.001, verbose=False)" ] }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from sklearn.metrics.pairwise import linear_kernel, polynomial_kernel, rbf_kernel # there are some more..\n", + "import numpy as np\n", + "import cvxopt\n", + "\n", + "class SVM_Predictor(object):\n", + " \n", + " def __init__(self, kernel):\n", + " self.kernel = kernel\n", + " \n", + " def train(self, X, Y):\n", + " n_samples, dim = X.shape\n", + " K = self.kernel(X)\n", + " P = cvxopt.matrix(np.outer(y, y) * K)\n", + " q = cvxopt.matrix(-1 * np.ones(n_samples))\n", + " G = cvxopt.matrix(-np.eye(n_samples))\n", + " h = cvxopt.matrix(np.zeros(n_samples))\n", + " A = cvxopt.matrix(y, (1, n_samples))\n", + " b = cvxopt.matrix(0.0)\n", + " #solvers.options['show_progress'] = False\n", + " sol = cvxopt.solvers.qp(P, q, G, h, A, b)\n", + " self.alphas = np.array(sol['x'])" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " pcost dcost gap pres dres\n", + " 0: -1.1154e+01 -2.0912e+01 3e+02 2e+01 2e+00\n", + " 1: -1.7080e+01 -8.7902e+00 6e+01 3e+00 4e-01\n", + " 2: -1.1243e+01 -5.9376e+00 4e+01 2e+00 2e-01\n", + " 3: -3.7199e+00 -4.0980e+00 2e+00 8e-02 8e-03\n", + " 4: -3.6777e+00 -3.7750e+00 2e-01 5e-03 6e-04\n", + " 5: -3.7366e+00 -3.7654e+00 3e-02 2e-04 3e-05\n", + " 6: -3.7626e+00 -3.7629e+00 4e-04 3e-06 3e-07\n", + " 7: -3.7629e+00 -3.7629e+00 4e-06 3e-08 3e-09\n", + "Optimal solution found.\n", + "x (100, 2) y (100,)\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XtwXNWdJ/Dvz4pkWzaZVGTVxsGWZGOBbRyWCSomNQHj\nJ2VpjE0yOzVkvLuQyZSwjdcZMGSzS1WqQpVqd5MYYnvWgJjJlIlVSag8wBA3YOMXVMJDsIbB74ce\neEkVstkEC1kSWL/943a3Wq2+3fd239v33nO/n6ou1K2rq9MS/t6j3zn3HFFVEBGRWSYE3QAiIvIe\nw52IyEAMdyIiAzHciYgMxHAnIjIQw52IyEAMdyIiAzHciYgMxHAnIjLQZ4L6xtOmTdOGhoagvj0R\nUSS9+eab51W1ttBxgYV7Q0MDOjs7g/r2RESRJCI9To5jWYaIyEAMdyIiAzHciYgMxHAnIjIQw52I\nyEAMdyIiAzHciUzX0QE0NAATJlj/7egIukVUBoHNcyeiMujoAFpbgYEB63lPj/UcANasCa5d5Dv2\n3IlM9uCDo8GeMjBgvU5GY7gTmay3193rZAyGO5HJ6urcvU7GYLgTmaytDaiuHvtadbX1OhmtYLiL\nyCQReV1E3haRIyLy/RzH3CUifSJyOPn4B3+aS0SurFkDtLcD9fWAiPXf9nYOpsaAk9kyQwCWqGq/\niFQCeEVEEqr6atZxv1DVDd43kYhKsmYNwzyGCoa7qiqA/uTTyuRD/WwUERGVxlHNXUQqROQwgA8A\n7FHV13Ic9tci8o6I/FJEZtqcp1VEOkWks6+vr4RmExFRPo7CXVUvq+r1AGYAuFFEFmQd8iyABlW9\nDsBeADtsztOuqk2q2lRbW3AjESIiKpKr2TKq+kcABwCsyHr9gqoOJZ8+AeAGT1pHRKNKXUaAyxDE\nipPZMrUi8rnkx5MBLANwPOuY6RlPVwE45mUjiWIvtYxATw+gOrqMgNOAdvr1vAAYQ6zx0jwHiFwH\nq8xSAeti8JSqPiQiDwHoVNVdIvI/YIX6pwA+BLBOVY/bnhRAU1OTcg9VIocaGqxAzlZfD3R3e/P1\n2evQANaceE6dDBUReVNVmwoeVyjc/cJwJ3JhwgSrx51NBBgZ8ebrC10AOjqsNWl6e607XNvaGPoB\ncBruvEOVKApKXUbAydfnW4em1LIQlR3DnSjMUjXwnh6rl53JzTICTpYhyHcB4OqSkcNwJwqrzN4y\nYPWYUwHvdhkBJ8sQ5LsAcHXJyGHNnSisSh1ELcb69VboX74MVFRYF5ft24NpC+XEmjtR1OUK03yv\n5+NkimNHB7BjhxXsgPXfHTus17m6ZOQw3InCqqLC3et2Ae50MDRfXZ2rS0YOyzJEYZU9gJop+99t\nrjnqVVXAFVcAFy7kPkd2SaXU6ZZUFizLEEVdfb3z13P1uoeH7YMdGD8Yyl2bjMJwJ/JLqbfyu6lz\nFzNrJTu0WVc3CsOdyA9e3PTjps7ttnedK7RZVzcKa+5EfvB66mChW/9z1dzt1Ndz6YAIc1pzd7LN\nHhG55eVNP9nBnforABgN6NR/UxeAz38e+Ogj4JNPRs/DRcBihWUZIj8UGpx0U493euv/mjXWXwUj\nI8D588C//itLLDHGsgyRH/Itnwu4W1rXzZRIMh6nQhIFKd/gpNtFuOxuWgLy9/gz/zqYNs16cBOO\n2GDPnajc3N4slK/nbjdAW2iAVcRqAwdXI8eznruITBKR10XkbRE5IiLfz3HMRBH5hYicFpHXRKSh\nuGYTxYCbm4U6OvKHu90Aba6/DjKlLi5er8vObfpCw0lZZgjAElX99wCuB7BCRL6Sdcy3APw/VZ0D\n4BEA/8vbZhIZxM3NQg8+mL+ubnehcDMrx6t12bmhR6gUDHe19CefViYf2f+3rYa1zyoA/BLAUpF8\n3Q2iGHNzs1C+kM5396jbm5q8WJedG3qEiqMBVRGpEJHDAD4AsEdVX8s65EoA7wGAqn4K4E8Aarxs\nKJFRMqctdnfb17ztQrqiIv/Uxlx/HeTjxfox3NAjVByFu6peVtXrAcwAcKOILMg6JFcvfdzfkiLS\nKiKdItLZ19fnvrVEcdPWBlRWjn2tstJaZz3fIGj2Xwc1NdYDGF/Dr6wE+vtLr5PbXSAmTGBpJgCu\npkKq6h8BHACwIutT5wDMBAAR+QyAPwPwYY6vb1fVJlVtqq2tLarBRLGTHcZOK57ZNzWdP2/Vwn/6\n07GhL2KtHllqndzur4XLl8eek4Ou5aGqeR8AagF8LvnxZAAvA1iZdcw9AB5LfnwHgKcKnfeGG25Q\nIspj507VigpVK3bHPurrvfke9fXenr9Qm3fuVK2uHvt6dbX1OjkCoFML5KuqFp7nLiLXwRosrYDV\n039KVR8SkYeS32SXiEwC8FMAfw6rx36Hqp7Nd17OcyfKw8k8dS820PBjg45856yr416sJfJs4TBV\nfQdWaGe//r2MjwcB/I3bRhKRjULz1D//eW++j13YljLAmu+cHHQtGy4/QBRGTsLOi9q11xt0dHRY\ng7PZUufkbk9lw3AnCqNCYXfhAvDNb5Z+w5CXG3SkSknZW/vV1Iyek7s9lQ3DneIrzLM2nMxTz1yr\nHSj+hiGnc+4LsSslTZ06dt157vZUFgx3iicnt8oHGf6pEKxxeS9gZjnHr/bbnddpPd2riwnl52RK\njR8PToWkQBWaAhimKXs7d+ZuaxDtz3der6dVUk5wOBWS4U7RkwoSkdG5026J5A4iEevzYQsqu/bY\nhbdf7c933jBdEA3mNNxZlqFo8WrlwUKzNsI2Za9QDV4EuPPO0RKHX+3Pd17W00OF4U7R4tXKg4Vm\nbYRtyl5mcOaiCuzePfrcr/YXOi/r6aHBcKdo8apHWqiX2dYGVFWN/7r+/uBm1aSC025tmcyfQa72\nV1WVPuWQUxkjg+FO0eJlj7RQL1Nz3EJ/4ULwG1A4/Rlktz/X+3GLpZfIYLhnUS/+AZB/ytVzfPDB\n8fPIU4LegKKlpfDrudr/ySfetJull0hguGdpbm7G7bffjscffxy9XO8ifMrVcyz0uw/y/43M2rrd\n63bt6+kJ181a5BuGewZVRWNjIw4fPoy1a9eivr4e1157LR544AHs27cPw8PDQTeRgPL0HAuVeYJc\nC8XJuEO+9v393wPTpoXzzlzyDMM9g4hg27Zt6OrqwtGjR7F582Z88YtfxNatW7F06VLU1NSwVx8X\n+aYeBj2A6KTmnq/9w8PebM5B4eZkMrwfjyjdxHTx4kV95plndO3atVpfX6+wthDU+fPn66ZNm3Tv\n3r06NDQUdDPJa5l3XaY2oCj2pimv2+XkZqFi7myl0INXm3X4Jaqbdagqjh8/jkQigUQigUOHDmF4\neBhTpkzBsmXL0NzcjObmZtRxCVPyU0eHNTja22v12NvacpenGhpyr62ezavNP8h3TjfrcLIT00wA\nTwL4AoARAO2quiXrmEUAngHQlXzp16r6UL7zRjXcs/X392P//v3YvXs3EokEepL/kObPn58O+ptu\nugkTJ04MuKUUS4V2dErhTkiR4TTcC3btAUwH8OXkx1cAOAlgftYxiwA85+RPhdQjSmUZp0ZGRvTo\n0aO6efNmXbZsmVZVVSkAnTJliq5atUofffRR7e7uDrqZVKydO1VrakZLGTU1wZdonMhci6emRrWy\nkuu/RBj8WjgMVg99edZrDPccLl68qLt27dJ169ZpQ0NDulY/b948ve+++3Tv3r06ODgYdDPJiZ07\nVauqxteqKyujF4xeLLxGgXEa7q5q7iLSAOAQgAWq+lHG64sA/ArAOQDvA7hfVY/kO5cpZRmnVBUn\nTpxAIpHA7t27x9Tqly5dmi7h1NutHULByle7ZkmDysizmnvGCacCOAigTVV/nfW5zwIYUdV+EWkB\nsEVVG3OcoxVAKwDU1dXd0ONkoMdQqVp9amC2OxkO8+bNSwf9zTffzFp9WEyYYH/7PgcjqYw8DXcR\nqQTwHIAXVPVhB8d3A2hS1fN2x8St555PZq8+kUjg4MGD7NWHTVA9d6ezYig2vBxQFVizZX6c55gv\nYPRCcSOA3tRzu0ccau7F6u/v12effVbXr1+fs1a/Z88e1uq94rT+HETNPcybX7BuHxh4NaAK4KZk\nuLwD4HDy0QJgLYC1yWM2ADgC4G0ArwL4y0LnZbg7MzIyoseOHdOHH35Yly9fPmYGzm233abbt2/X\nrq6uoJsZXvlCyG14lnu2TNh2g0oJ80UnBpyGO29iipiPP/54TK2+q8u6tWDu3LloaWlhrT5Trjne\n1dWjC43ZlVrCMkBqV+cPusYf9p+b4TwfUPUaw710qoqTJ0+mg/7AgQPpWv2SJUvStfqGhoagmxqM\nQiEU1vBMCWuIhv3nZjiGewzl69Wngn7hwoXx6dUXCqGwhmdKob88ghL2n5vhPBtQ9evBmru/RkZG\n9Pjx4/rII4/orbfeqhMnTlQAWl1drStXroxHrb5QzToKteMwDlxG4edmMPh1h6pXD4Z7efX39+tz\nzz2n99xzj86ePTs9A2fu3Ll677336osvvmjeDBwnIRTG8IwC/twC4zTcWZaJIVXFqVOn0oudHTx4\nEENDQ6iurh5Tq581a1bQTS0d54mTYVhzJ8cGBgbG1OrPnj0LALjmmmvQ3NyMlpaWeNXqiUKM4U5F\nSfXqM2fgGNurJ4oghjt5YmBgAAcOHEiXcLJ79akZOJMmTQq4pUTxwHAnz+Xr1S9evDgd9rNnzw66\nqfHFMQbjMdzJd6lefSrsz5w5A4C9+sCUMi+eF4XIYLhT2WX26vfv389efbkVe3NRWG+WopwY7hQo\nu1791VdfnQ76W265hb16LxW7LADvOI0UhjuFSnatfnBwEJMnTx4zA4e9+hIVG9JcKyZSGO4UWgMD\nAzh48GB6Bg579R4ptrzCnnukMNwpMux69YsXL04vY8xevUPFDIyy5h4pDHeKpEuXLo2p1Z8+fRoA\ne/W+42yZyPAs3EVkJqxt9r4AYARAu6puyTpGAGyBtUPTAIC7VPWtfOdluHvD9H+Tp0+fRiKRwO7d\nu8f16lNhf9VVVwXdTKKy8TLcpwOYrqpvicgVAN4EcLuqHs04pgXAf4EV7n8BYIuq/kW+8zLcSxe3\nv6btevWNjY3pNXDYqyfT+VaWEZFnAPyTqu7JeO1xAAdU9WfJ5ycALFLVP9idh+FeuriPg6V69al5\n9ezVUxz4Eu4i0gDgEIAFqvpRxuvPAfifqvpK8vlLAP6rqtqmN8O9dJzBNurSpUs4ePBguoST3atP\n1eonT54ccEuJSuM03Ce4OOFUAL8C8I+ZwZ76dI4vGRc7ItIqIp0i0tnX1+f0Wxeto8Pq3U6YYP23\no8P3b1lWdXXOXzf9ZzF58mSsWLECW7ZswalTp3Dq1Cls3boVc+bMQXt7O5qbm1FTU4OWlhZs27Yt\nHf5ExnKyoweASgAvALjP5vOPA/hGxvMTsOr0ge3EFIedwJy+xzj8LPIZGBjQRCKhGzdu1MbGxvQu\nVI2Njbpx40ZNJBI6MDAQdDOJHIFX2+zB6pU/CeDHeY75KwCJ5LFfAfB6ofP6He6Fts80hZPdzuLy\ns3Dq9OnTum3bNm1padFJkyYpAJ00aZI2Nzfr1q1b9dSpU0E3kciW03B3MlvmJgAvA/g3WFMhAeC/\nA6hL9vwfS06F/CcAK2BNhfym5qm3A/7X3FmPHsWfhb3MWn0ikcCpU6cAAHPmzEnX6hctWsRaPYVG\n7G9iivtMkkx2P4uaGuD8+bI3J9TOnDkzZgbOpUuXMGnSpDEzcObMmRN0MynGPB9QjZq2NmvOd6bq\naut1UzgdJG1rAyorx79+4QIwbZp5g6uluOqqq7Bhwwb89re/xYULF/D888/j7rvvxpkzZ7Bx40Y0\nNjaisbERGzduRCKRwKVLl4JuMlFuTmo3fjz8rrmrOqtHR5XbQdKamtx197gNrpYis1Y/efLkdK1+\nxYoVumXLFj158mTQTaQYgFc1d79wnntp3Jad7Oruhb6OchscHBxTqz958iQAq+efuluWtXryQ+xr\n7qZzO0hqdzEo9HXkzNmzZ9NBv2/fvnStftGiRelafWNjY9DNJAMw3A3ntueeax0aJ19H7hXq1adm\n4FRnDwoROcBwN1wxi4Z1dADf/rY1kJrJ5MXGwsCuV3/LLbekSzjs1ZNTDPcYKHa5X9OXCQ6zwcFB\nHDp0KB32J06cAMBePTnHcCeKgFy9+okTJ46r1Vv3CRIx3Ikix65XP3v27HTQL168mL36mGO4lwHL\nG+Snrq6uMb36gYEBTJw4Ebfcckt6b1n26uOH4e6zuO2CFGZxuMgODg7i5ZdfTq9Xz159fDHcfca1\na8IhrhfZfL36VNhfffXV7NUbiOHusyivtGhST5cX2bG9+kQigePHjwMAZs2alZ5qyV69OZyGu9Fr\ny9jxYs2ZqK6RbtrGHSK5fw8iQbcsOGfPntXt27frypUrtbq6WgHoxIkT9dZbb9VHHnlEjx8/riMj\nI0E3k4oErzbr8OsRVLh7FW5RDcmoXpTsmPZ+vHbp0iV98cUX9d5779W5c+emd6GaNWuWrl+/Xp99\n9lnt7+8PupnkAsPdhpdhEMVVJ6PS03X6s43qRTYoXV1dun37dr3tttvG9OqXL1+uDz/8MHv1EeBZ\nuAP4CYAPALxr8/lFAP4E4HDy8T0n3ziocI9KuPklCj1dt4EdxYtsGAwODuqePXv0vvvuY68+QpyG\nu5Nt9hYC6AfwpKouyPH5RQDuV9WVBQv8GYIaUI37AFwUZpfE/XcUlO7u7vSg7EsvvZSegbNw4cL0\nDJxrrrmGM3AC5umAKoAG5O+5P+fkPJmPqNfcoyzsPd24/3UVBpm9+nnz5qV79Q0NDbpu3TrdtWsX\ne/UBgZc1dwfhfgHA2wASAK51cs6oz5aJ4veOiiiUjuKmq6tLH330UV21apVOmTJFAWhVVVW6Vn/s\n2DHW6suknOH+WQBTkx+3ADiV5zytADoBdNbV1ZXj5xAq/KvBGf6cwo29+mA5DXdHNzGJSEOy9DKu\n5p7j2G4ATap6Pt9xUb+JqRisJTtn0o1Wpuvp6RlTq//4449RVVWFhQsXptfAYa3eO+WsuX8Bo3e6\n3gigN/U83yPIskxQ3NSSTSjfmPAeyJ3BwUHdu3evbtq0SefPn89evQ/g4VTInwH4A4BPAJwD8C0A\nawGsTX5+A4AjsGrurwL4SyffOI7h7rSWbEJZwoT3QKXr7u7OWatftmyZbt68WY8ePcpavUtOw51r\ny5SR02mIJpRvTHgP5K2hoSG88sor6RLO0aNHAQD19fXpNXCWLFmCKVOmBNzScOPCYSHlpJYc5UXJ\nUkx4D+Svnp4ePP/880gkEti7d++YWn1qXv3cuXNZq8/CcI+wcvZ6/Rq4ZM+d3BgeHk736nfv3j2u\nV9/c3IwlS5Zg6tSpAbc0eFwVMsJKrVeHYV0W1typFD09PfrYY4/p6tWrderUqela/dKlS/VHP/pR\nrGv14MJh0VbsTBM3oer3zUKcLUNeGBoa0pdeeknvv/9+vfbaa9MzcOrr63Xt2rX6zDPP6MWLF4Nu\nZtk4DXeWZSLATenETTmEdXGKot7e3jG1+v7+flRVVeHmm29Ol3DmzZtnbK3e+LJMXHqFbssbbubS\ne9Vzj8vvgsInX6/+7rvv1qefftq4Xj1MLsvEqZ7rNoDdHO/Fz9H03wUvXNHS09Ojjz/+uN5+++05\na/VHjhyJfK3e6HCP08JSbldILPda6OX8XZQ7aE2/cJluaGhI9+3bpw888IAuWLAg3auvq6uLdK/e\n6HCP05KwxYRnOUOwXL+LIII2Tp2IOOjt7R3Xq6+srNQlS5boD3/4Q3333Xcj0as3Otzj9I/OSagF\nWToo1+8iiN95nDoRcRPlXr3R4R63P5fzhXfQP4tiv7/bC1IQQRunTkTc9fb2ant7eyR69UaHuyoH\nulLCEEBufxfFXBCCeJ9BXzgpGPl69a2trfr000/rRx99FFj7jA93skSxdFDsOEIQQctOBKV69V/7\n2tf0iiuuCLxXz3CPiSjOVS/2gsSgpaANDQ3p/v379Tvf+c6YXv3MmTO1tbVVf/Ob3/jeq2e4x0TU\n5qrv3KlaUeHNBYkoaHa9+sWLF+sPfvADX3r1DPcYKWUdGruevx9hm+siwlo2mSKzV/+lL33Jt149\nwz1msgN+3br8gZ8vaP2q29tdSCoqGOxknvfee0+feOIJ/frXvz6uV//UU08VfV6n4V5w4TAR+QmA\nlQA+0BwbZIu1Os8WAC0ABgDcpapvFVrThguHeSfXDk/Zsnd8sltgLJPXa69zoTKKq+HhYfzud79L\n70J15513YtOmTUWdy7OFwwAsBPBl2G+Q3QIgAUAAfAXAa06uKib03MMywJevtGJXZil0rB9lkjBM\n2yQKg8uXLxf9tXDYc5/gIPwPAfgwzyGrATyZ/L6vAviciEwveFWJuFRvuafHiqieHut5R0f529Lb\n6/64igr74+rrx+/r6oW2NusviEzV1dbrRHEyYULB6C39e3hwjisBvJfx/FzytXFEpFVEOkWks6+v\nz4NvHZwHHxxfBhkYsF4vt7o698ddvmx/XHe398EOWOdsb7cuHiL+XUSIyJtwz7Uifs5Cvqq2q2qT\nqjbV1tZ68K2DY9dbdtqL9lKuHnG27B5yfX3u4+xe98qaNdbFY2TEv4sIEXkT7ucAzMx4PgPA+x6c\nN9TsestOe9FeSvWIa2pyf37CBODOO8cGaa4LQlUV0N9vHd/QEEyJiYi84UW47wLwn8XyFQB/UtU/\neHDeUAtb/XjNGsBuY/iREWDHjtGwTm3bNzAwWnuvqbHGDi5cCH4MgYhKVzDcReRnAH4P4BoROSci\n3xKRtSKyNnnIbgBnAZwG8ASA9b61NkTCWD/OVxJKjQdkDgQDVu09dZH65JPcX5Opo8Pq1bN3TxRu\n3CDbIIXmrotYZaNC89uzvyY1Bz3XfPrs+fNE5C+n89z9n49DZVNoYLWuzv2Ab+YYQphmCBFRfgx3\ng+QbWE2NB9gN+NbUFB5DCNMMISLKj+EeQqXUtdesAc6fB3buzD0eYDcQvGVL4TGEMM0QIqICnNzG\n6sfDhOUH/ODl8rt2yyOUsookdyYiCha4KmS4OA1ULzff8COIS1lPJyxr8RBFGcM9RNwEbTG7FOUK\nzbAt0sVeP5E3nIY7p0KWgd0UxVxL6ro5FrCfnmi3/G9Qy+u6fV9ElBunQhbBrxt03MwycXvnq930\nRLtVH4Ma/ORMG6LyYrgn+bmEr5tZJm7vfLULx8w7T1NEgJYW5+12wukFkTNtiMrMSe3Gj0fYau5+\n1qj9rDfna/e6deNr+F7Wud28L9bcibwBDqi6U8xApht+zRTJF5p+D6q6PT9nyxCVzmm4c0A1KcoD\nfqlVHnt7rTJHW5tVxvF7z1LuiUpUfhxQdSlsS/i6YbcBht91btbRicKL4Z4UxiV8S+X3BSvKF0Qi\n0zHcM5i2BZzfFywTL4hEpmDNnYgoQjytuYvIChE5ISKnReS7OT5/l4j0icjh5OMfimk0ERF5w8k2\nexUA/jeAZgDzAXxDRObnOPQXqnp98vHPHreTsnC7OyLKx0nP/UYAp1X1rKoOA/g5gNX+Nivcgg5W\nP++mJSIzOAn3KwG8l/H8XPK1bH8tIu+IyC9FZGauE4lIq4h0ikhnX19fEc0NXhiCldvdEVEhTsJd\ncryWPQr7LIAGVb0OwF4AO3KdSFXbVbVJVZtqa2vdtTQkwhCsXISLiApxEu7nAGT2xGcAeD/zAFW9\noKpDyadPALjBm+aFTxiClTcPEVEhTsL9DQCNIjJLRKoA3AFgV+YBIjI94+kqAMe8a2K4hCFYo3Dz\nUNDjEkRxVzDcVfVTABsAvAArtJ9S1SMi8pCIrEoetlFEjojI2wA2ArjLrwYHLQzBGvabh8IwLkEU\nd7yJqQh2C3WRJcqLsBGFndObmBju5DmuFknkH64KSYEJw7gEUdwx3MlzYRiXIIo7hjt5LuwDvkRx\n8JmgG0BmWrOGYU4UJPbciYgMxHAnIjIQw52IyEAMdyIiAzHciYgMxHAnX3DhMKJgcSokeS61cFhq\n3fvUwmEAp0cSlQt77uS5MGxoQhR3DHfyXBg2NCGKO4Y7eSKzxj7B5v8qLhxGVD4Md4/EeQAxe3OO\ny5fHH8OFw4jKy1G4i8gKETkhIqdF5Ls5Pj9RRH6R/PxrItLgdUPDLO47D+WqsQNARQUXDiMKSsHN\nOkSkAsBJAMthbZb9BoBvqOrRjGPWA7hOVdeKyB0Avqaqf5vvvCZt1hH3nYe4OQdR+Xi5WceNAE6r\n6llVHQbwcwCrs45ZDWBH8uNfAlgqIuKmwVEW9wFEbs5BFD5Owv1KAO9lPD+XfC3nMckNtf8EoMaL\nBkZB3MONm3MQhY+TcM/VA8/+I9zJMRCRVhHpFJHOvr4+J+2LhLiHGzfnIAofJ+F+DsDMjOczALxv\nd4yIfAbAnwH4MPtEqtquqk2q2lRbW1tci0OI4Wa91+5uq8be3R2v904URk6WH3gDQKOIzALwfwHc\nAeDvso7ZBeBOAL8H8B8A7NNCI7WG4c5DRBQmBcNdVT8VkQ0AXgBQAeAnqnpERB4C0KmquwD8C4Cf\nishpWD32O/xsNBER5edo4TBV3Q1gd9Zr38v4eBDA33jbNCIiKhbvUCUiMhDDnYjIQAx3IiIDMdyJ\niAzEcDdEnFelJKLxuM2eAbitHRFlY8/dANzWjoiyMdwNEPdVKYloPIa7AeK+KiURjcdwN0DcV6Uk\novEY7gbgqpRElI2zZQzBVSmJKBN77kREBmK4ExEZiOFORGQghjsRkYEY7kREBmK4ExEZiOFORGQg\nUdVgvrFIH4CeQL55YdMAnA+6ET7i+4s2098fYP57LOX91atqbaGDAgv3MBORTlVtCrodfuH7izbT\n3x9g/nssx/tjWYaIyEAMdyIiAzHcc2sPugE+4/uLNtPfH2D+e/T9/bHmTkRkIPbciYgMxHDPICIr\nROSEiJwWke8G3R6vichPROQDEXk36Lb4QURmish+ETkmIkdE5NtBt8lLIjJJRF4XkbeT7+/7QbfJ\nDyJSISKJtUaJAAACZ0lEQVT/R0SeC7otXhORbhH5NxE5LCKdvn4vlmUsIlIB4CSA5QDOAXgDwDdU\n9WigDfOQiCwE0A/gSVVdEHR7vCYi0wFMV9W3ROQKAG8CuN2U36GICIApqtovIpUAXgHwbVV9NeCm\neUpE7gPQBOCzqroy6PZ4SUS6ATSpqu9z+NlzH3UjgNOqelZVhwH8HMDqgNvkKVU9BODDoNvhF1X9\ng6q+lfz4IoBjAK4MtlXeUUt/8mll8mFU70xEZgD4KwD/HHRboo7hPupKAO9lPD8Hg4IhbkSkAcCf\nA3gt2JZ4K1myOAzgAwB7VNWo9wfgxwC+A2Ak6Ib4RAG8KCJvikirn9+I4T5KcrxmVK8oLkRkKoBf\nAfhHVf0o6PZ4SVUvq+r1AGYAuFFEjCmvichKAB+o6ptBt8VHX1XVLwNoBnBPslTqC4b7qHMAZmY8\nnwHg/YDaQkVK1qJ/BaBDVX8ddHv8oqp/BHAAwIqAm+KlrwJYlaxL/xzAEhHZGWyTvKWq7yf/+wGA\n38AqB/uC4T7qDQCNIjJLRKoA3AFgV8BtIheSA47/AuCYqj4cdHu8JiK1IvK55MeTASwDcDzYVnlH\nVf+bqs5Q1QZY//72qep/DLhZnhGRKcmBfojIFAC3AvBt5hrDPUlVPwWwAcALsAbinlLVI8G2ylsi\n8jMAvwdwjYicE5FvBd0mj30VwH+C1eM7nHy0BN0oD00HsF9E3oHVGdmjqsZNFzTYvwPwioi8DeB1\nAL9V1ef9+macCklEZCD23ImIDMRwJyIyEMOdiMhADHciIgMx3ImIDMRwJyIyEMOdiMhADHciIgP9\nf7clxzX6+Kk6AAAAAElFTkSuQmCC\n", + "text/plain": [ + "<matplotlib.figure.Figure at 0x7fe285ddec50>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "dim = 2\n", + "colours = ['red', 'blue']\n", + "\n", + "# 2-D mean of ones\n", + "M1 = np.ones((dim,))\n", + "# 2-D mean of threes\n", + "M2 = 3 * np.ones((dim,))\n", + "# 2-D covariance of 0.3\n", + "C1 = np.diag(0.3 * np.ones((dim,)))\n", + "# 2-D covariance of 0.2\n", + "C2 = np.diag(0.2 * np.ones((dim,)))\n", + "\n", + "def generate_gaussian(m, c, n_samples):\n", + " return np.random.multivariate_normal(m, c, n_samples)\n", + "\n", + "def plot_data_with_labels(x, y, ax):\n", + " unique = np.unique(y)\n", + " for li in range(len(unique)):\n", + " x_sub = x[y == unique[li]]\n", + " ax.scatter(x_sub[:, 0], x_sub[:, 1], c = colours[li])\n", + "\n", + "def plot_separator(w, b, ax):\n", + " slope = -w[0] / w[1]\n", + " intercept = -b / w[1]\n", + " x = np.arange(0, 6)\n", + " ax.plot(x, x * slope + intercept, 'k-')\n", + " \n", + "n_samples = 50\n", + "\n", + "# generate 50 points from gaussian 1\n", + "x1 = generate_gaussian(M1, C1, n_samples)\n", + "# labels\n", + "y1 = np.ones((x1.shape[0],))\n", + "# generate 50 points from gaussian 2\n", + "x2 = generate_gaussian(M2, C2, n_samples)\n", + "y2 = -np.ones((x2.shape[0],))\n", + "# join\n", + "x = np.concatenate((x1, x2), axis = 0)\n", + "y = np.concatenate((y1, y2), axis = 0)\n", + "\n", + "lin_clf = SVM_Predictor(linear_kernel)\n", + "lin_clf.train(x, y)\n", + "\n", + "# get weights\n", + "w = np.sum(lin_clf.alphas * y[:, None] * x, axis = 0)\n", + "# get bias\n", + "cond = (lin_clf.alphas > 1e-4).reshape(-1)\n", + "b = y[cond] - np.dot(x[cond], w)\n", + "bias = b[0]\n", + "\n", + "print('x {} y {}'.format(x.shape, y.shape))\n", + "fig, ax = plt.subplots()\n", + "plot_separator(w, bias, ax)\n", + "plot_data_with_labels(x, y, ax)\n", + "plt.show()" + ] + }, { "cell_type": "markdown", "metadata": {}, -- GitLab