diff --git a/_toc.yml b/_toc.yml index daf8a5a..9eaad8d 100644 --- a/_toc.yml +++ b/_toc.yml @@ -5,6 +5,7 @@ parts: chapters: - file: notebooks/introduction/mle_coin - file: notebooks/introduction/variational + - file: notebooks/introduction/categorical_distribution - caption: Probability - Univariate Models chapters: - file: notebooks/probability/Sample_Space_and_Random_Variables diff --git a/notebooks/introduction/categorical_distribution.ipynb b/notebooks/introduction/categorical_distribution.ipynb new file mode 100644 index 0000000..734f7bf --- /dev/null +++ b/notebooks/introduction/categorical_distribution.ipynb @@ -0,0 +1,683 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "3399f1ad", + "metadata": {}, + "outputs": [], + "source": [ + "from collections import Counter\n", + "\n", + "import jax\n", + "import jax.numpy as jnp \n", + "\n", + "import matplotlib.pyplot as plt\n", + "from ipywidgets import interact\n", + "\n", + "\n", + "try:\n", + " import distrax\n", + "except ModuleNotFoundError:\n", + " %pip install distrax\n", + " import distrax\n", + "try:\n", + " import optax\n", + "except ModuleNotFoundError:\n", + " %pip install optax\n", + " import optax" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "039a2d34", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + } + ], + "source": [ + "key = jax.random.PRNGKey(1)" + ] + }, + { + "cell_type": "markdown", + "id": "53cded78", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# Categorical Distribution" + ] + }, + { + "cell_type": "markdown", + "id": "4f79695c", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "source": [ + "The Bernoulli distribution models a discrete random variable with only two states.\n", + "What if we want to model more than two states, perhaps even hundreds? This is where the categorical distribution comes into play: it generalizes the Bernoulli distribution to a random variable with $C$ possible categories." + ] + }, + { + "cell_type": "markdown", + "id": "60c13bea", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + " A distribution is categorical when it meets the following criteria:\n", + " - The categories are discrete.\n", + " - There are two or more potential categories.\n", + " - Every category has a probability between 0 and 1, and the probabilities of all categories sum to 1.\n" + ] + }, + { + "cell_type": "markdown", + "id": "4b62e9bb", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Let's take a finite set of labels, $ x\\ \\in\\ \\{1,\\dots,C\\} $\n", + "\n", + "PMF:\n", + "\\begin{equation}\n", + " p(X = x| \\vec{\\theta}) = \\prod_{c=1}^{C} \\theta_c ^{I(x=c)}\\ ,\\ \\text{where}\\ I(x=c) = \\begin{cases}\n", + " 1, & \\text{if}\\ x = c \\\\\n", + " 0, & \\text{otherwise}\n", + " \\end{cases} \\tag{eq. 1}\n", + "\\end{equation}\n", + "In other words, $p(X = c|\\vec{\\theta}) = \\theta_c$." 
+ ] + }, + { + "cell_type": "markdown", + "id": "0a0d4c79", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Let's take the familiar example of rolling a fair die:\n", + "\n", + "There are $C = 6$ categorical outcomes, and each has probability 1/6 because the die is fair.\n", + "\n", + "Because distrax numbers the categories from 0, the sample space is: { 0, 1, 2, 3, 4, 5 }" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6d42c295", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "# a fair die has the same probability for every category\n", + "theta = [1/6]*6\n", + "\n", + "cat = distrax.Categorical(probs=theta)\n", + "n = 1000  # number of samples\n", + "cat_samples = cat.sample(seed = key, sample_shape=n)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "22ea76e0", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "There are 6 distinct categories in the samples: \n", + "[0 1 2 3 4 5]\n" + ] + } + ], + "source": [ + "category = jnp.unique(cat_samples)\n", + "print(\"There are {} distinct categories in the samples: \".format(len(category)))\n", + "print(category)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0461f5e2", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0, 0.5, 'No. 
of occurences.')" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAELCAYAAADHksFtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAapElEQVR4nO3de5RkZXnv8e8veEHReGPAOeI4qMgRSDLqiBqVYPCCaLyDcCKiwTVeIIGjiYCaIyZRUS6aeIw4yIisQ0AN4v3GAYKSRHEGEJFbAIcAZ5wZAQEVMcBz/th7pGi6a3ZPV3XVdH8/a+3Vtd93166nYPU8vd9rqgpJkqbyO6MOQJI03kwUkqS+TBSSpL5MFJKkvkwUkqS+7jfqAAZt6623rsWLF486DEnarKxatepnVbVgsro5lygWL17MypUrRx2GJG1Wklw7VZ1NT5KkvkwUkqS+TBSSpL5MFJKkvkwUkqS+TBSSpL5mNVEkWZFkXZJLeso+m+Si9lid5KK2fHGS23vqjp/NWCVJjdmeR3ES8L+BkzcUVNVrN7xOcixwS8/1V1fVktkKTpJ0X7OaKKrqO0kWT1aXJMA+wB/PZkySpP7GaWb2c4G1VfUfPWXbJ7kQuBV4T1V9d7I3JlkGLANYtGjR0AOVBmnx4V8bdQidrD7qJZ2um2vfR+PVmb0fcGrP+RpgUVU9BXg78E9JfneyN1bV8qpaWlVLFyyYdKkSSdImGotEkeR+wKuAz24oq6o7qurG9vUq4GrgSaOJUJLmr7FIFMDzgcur6voNBUkWJNmiff14YAfgmhHFJ0nz1mwPjz0V+HdgxyTXJzmwrdqXezc7AewGXNwOl/1n4C1VddOsBStJAmZ/1NN+U5S/YZKy04HThx2TJKm/cWl6kiSNKROFJKkvE4UkqS8ThSSpr3GamT0WnFUqSffmE4UkqS8ThSSpL5ue5jib0iTNlE8UkqS+TBSSpL5setJmxaY0afb5RCFJ6stEIUnqy6YnSdqI+d7k6ROFJKkvE4UkqS8ThSSpLxOFJKkvE4UkqS8ThSSpr1lNFElWJFmX5JKesiOT3JDkovbYq6fuiCRXJbkiyYtmM1ZJUmO2nyhOAvacpPwjVbWkPb4OkGQnYF9g5/Y9/5hki1mLVJIEzHKiqKrvADd1vPzlwGlVdUdV/QS4Cth1aMFJkiY1Ln0UBye5uG2aekRb9hjgup5rrm/L7iPJsiQrk6xcv379sGOVpHllHBLFJ4AnAEuANcCx071BVS2vqqVVtXTBggUDDk+S5reRJ4qqWltVd1XV3cAJ3NO8dAPw2J5Lt2vLJEmzaMaJIskJSU6cwfsX9py+EtgwIurLwL5JHphke2AH4PxNj1SStCkGsXrs8+iYcJKcCuwObJ3keuC9wO5JlgAFrAbeDFBVP07yOeBS4E7goKq6awDxSpKmYcaJoqqeOI1r95ukeMqnkap6P/D+TYlLkjQYI++jkCSNt65NRk9O8sye8wcl+UCSLyb58+GFJ0kata5PFP8I/EnP+dHAIcCWwIeS/NWgA5MkjYeuiWIX4N8Bktwf2B84tKr2BN4F/NlwwpMkjVrXRLEVcGv7+pnt+Rfa8wuAxw04LknSmOiaKH5CkyCgmetwYVXd2J5vDdw26MAkSeOh6/DY44BPJNkbeArwxp663YGLBxyXJGlMdEoUVXVikv8Ang4cXlVn9VTfBHx0CLFJksZA5wl37RLh35mk/MhBBiRJGi+dJ9wl2SbJh5KcleTKJDu35YckedbwQpQkjVLXCXe70mwc9Gqa9ZieADywrV4IvGMYwUmSRq/rE8VHgLOBJ9Es2peeuvNx5zlJmrO69lE8FXh5Vd2dJBPqbgS2
GWxYkqRx0fWJ4hZgqq3jHg+sHUw4kqRx0zVRfBl4X5LH95RVkq2Bv+SeWdqSpDmma6I4jGYJj0u5Z4js8cAVwO3A/xp8aJKkcdB1wt3N7TLj+wN7AL+kmWj3KeDkqrpjeCFKkkZpOhPufkOzG90m748tSdr8dJ1HsUeSN0xR94YkzxtoVJKksdG1j+L9wLZT1G0NfGAw4UiSxk3XRLEzsHKKuguBnQYTjiRp3HRNFHcCj5yi7lFdPyzJiiTrklzSU3Z0ksuTXJzkjCQPb8sXJ7k9yUXtcXzXz5EkDU7XRHEe8FdJHtBb2J6/A/hux/ucBOw5oexMYJeq+n3gSuCInrqrq2pJe7yl42dIkgao66ind9Mki6uSfBZYQ7MY4D7Aw4ADu9ykqr6TZPGEsm/3nH4PeE3HmCRJs6DTE0VVXUyzadG/0syl+FD78zxg16q6pM/bp+PPgG/0nG+f5MIk5yZ57lRvSrIsycokK9evXz+gUCRJML15FFcA+w0rkCTvpukLOaUtWgMsqqobkzwN+GKSnavq1kliWw4sB1i6dGkNK0ZJmo86b1w0TO0cjZcCf1pVBVBVd1TVje3rVcDVNMucS5JmUecniiSvAV4FbAdsObG+qjZpT4okewLvBP6oqn7VU74AuKmq7moXI9wBuGZTPkOStOk6JYokR9Is/PdDmoUBf7MpH5bkVGB3YOsk1wPvpRnl9EDgzHari++1I5x2A/4myX8BdwNvqaqbNuVzJUmbrusTxYHAUVX1rpl8WFVN1scx6dpRVXU6cPpMPk+SNHNd+ygeCpw1zEAkSeOpa6I4jftOlJMkzQNdm57OAj7U7mh3JvDziRdU1dcHGJckaUx0TRSfbX8uBg6YpL6ALQYRkCRpvHRNFNsPNQpJ0tjquhXqtcMORJI0njrPzE7ywCRvTXJikm8n2aEtf22SJw8vREnSKHWdcPckmk7shwGraCbNPbStfi7wEuD1Q4hPkjRiXZ8o/gH4T5rO7BcB6ak7F3jOYMOSJI2Lrp3ZzwX2rqqfJ5k4umktzd4UkqQ5qOsTxa+BB01R9xgmmVchSZobuiaKM4F3JXlYT1kleSDw54CT7SRpjura9PRXNLvbXUWTNIpmNdmdgQfQLD8uSZqDum6Feh3wB8DxNB3aV9P0S3weeFpV/XRYAUqSRmujTxRJtgQ+BpxYVX8N/PXQo5IkjY2NPlFU1a+BfZlkVztJ0tzXtTP7bOB5wwxEkjSeunZmfxz4VJKtaEY4raXp0P6tqrp0wLFJksZA10Txzfbn29ujN0kElxmXpDmra6Kw2UmS5qmuy4yfO6gPTLICeCmwrqp2acseSbM50mJgNbBPVd2cJMDfA3sBvwLeUFUXDCoWSdLGderMTvLgjR3T+MyTuO/+24cDZ1XVDjTbrh7elr8Y2KE9lgGfmMbnSJIGoOuop18At23k6KSqvgPcNKH45cBn2tefAV7RU35yNb4HPDyJCxBK0izq2kfxZ0wY5QQ8gmbJ8Z2Av51hHNtW1Zr29U+BbdvXjwGu67nu+rZsTU8ZSZbRPHGwaNGiGYYiSerVtY/ipCmqPprkEzRrPg1EVVWSiUlpY+9ZDiwHWLp06bTeK0nqr/NWqH2czsx3t1u7oUmp/bmuLb8BeGzPddu1ZZKkWTKIRPF04I4Z3uPLwAHt6wOAL/WUvz6NZwK39DRRSZJmQdc9sz88SfEDgCcDewAf7fqBSU6l2XN76yTXA+8FjgI+l+RA4Fpgn/byr9MMjb2KZnjsG7t+jiRpMLp2Zu89SdmvaTqX/4K2f6CLqtpviqo9Jrm2gIO63luSNHhdO7O3H3YgkqTxNIg+CknSHNZ1ZvaKJKdNUXdqkhMGG5YkaVx0faJ4Ac0w2MmcTjPxTpI0B3VNFAu477IbG9wMbDOYcCRJ46ZrorgW2G2Kut1oRj9JkuagroniJOCwJAcleQhAkockeRvwTuBTQ4pPkjRiXedRfAh4AvAx4B+S/BLYimZ3u+Vt
vSRpDuo6j+Ju4E1JjqbZ7e5RwI3A2VV15RDjkySNWNcnCgCq6grgiiHFIkkaQ13nUfxFkqOmqPtgkoMHG5YkaVx07cx+G83CfJO5sq2XJM1BXRPF45g6UfwEWDyQaCRJY6drorgZ2HGKuh2BWwcTjiRp3HRNFF8Bjkzye72FSXah2U/iS5O+S5K02es66ukI4A+BC5NcCKwBFgJPAS4BDh9OeJKkUev0RFFVN9FseXoQcDXwoPbnW4FnVNXNQ4tQkjRSnedRVNWvgU+2hyRpnpjWhLskzwCeAzySZmb2eVV1/jACkySNh06JIslWwOeBPYE7aZLEo4AtknwT2LuqfjW0KCVJI9N11NOHgWcBrwW2rKqFwJbAvm25iwJK0hzVNVG8Gjisqj7fLhBIVd1dVZ+nGfG090yCSLJjkot6jluTHJrkyCQ39JTvNZPPkSRNX9c+iocB101Rdx3wuzMJol1scAlAki2AG4AzgDcCH6mqY2Zyf0nSpuv6RPFD4K1J0lvYnr+1rR+UPYCrq+raAd5TkrSJuj5RvAv4BnB5kjOAtTT7ZL+SZp2nFw8wpn2BU3vOD07yemAl8I7J5mwkWQYsA1i0aNEAQ5EkdZ1wdzbwVOBCmv6I9wP7ABcAT62qcwYRTJIHAC+jGWEF8AmanfWW0MwGP3aK+JZX1dKqWrpgwYJBhCJJak1nwt2Paf7aH6YXAxdU1dr2M9duqEhyAvDVIX++JGmCrn0Us2U/epqdkizsqXslzbpSkqRZNK2Z2cPUTup7AfDmnuIPJ1kCFLB6Qp0kaRaMTaKoql/SzPbuLdt/ROFIklrj1vQkSRozUyaKJCuSbN++3i3JQ2YvLEnSuOj3RHEAsGGs6TnATsMPR5I0bvr1UawBdk9yKRBgyyQPnupiV4+VpLmp3xPFcuAo4BaaUUfnALf1OSRJc9CUTxRV9TdJvgY8GTgZ+Dua7U8lSfNI3+GxVbUKWJVkD+DTVfWT2QlLkjQuOs2jqKo3bnid5FE0W6HeVFU3DiswSdJ46DyPIslrk1wGrAMuB9YluSzJjDYtkiSNt657Zu8HnEKz1PgHaZYZ35Zma9TTkmxRVacNLUpJ0sh0XcLj3cDyqnrLhPKTkxwPvAcwUUjSHNS16emJwOlT1J3e1kuS5qCuiWItsHSKuqVtvSRpDura9PRp4MgkWwD/zD1boe5N0+z0weGEJ0kata6J4m+A+wOHA+/rKb8dOKatlyTNQV3nUdwNvDvJMcAuwEKataAuqaqbhxifJGnEprVxUZsUvjukWCRJY8iNiyRJfZkoJEl9mSgkSX1Nq49i2JKsptnb4i7gzqpamuSRwGeBxcBqYB870CVp9ozjE8XzqmpJVW2Y4Hc4cFZV7QCc1Z5LkmbJjBNFktcl2X8QwUzh5cBn2tefAV4xxM+SJE0wiCeKTwMnDeA+0Gy5+u0kq5Isa8u2rao17euf0qxaK0maJYPoo3gCkAHcB+A5VXVDkm2AM5Nc3ltZVZWkJr6pTSrLABYtWjSgUCRJMIAniqr6z6q6dhDBVNUN7c91wBnArsDaJAsB2p/rJnnf8qpaWlVLFyxYMIhQJEmt6exwd792l7uPJTml/blPkoGMnEqyVZKHbngNvBC4BPgycEB72QHAlwbxeZKkbrrucLcN8G3g92mGqK4FngUcBPwwyQurav0MY9kWOCPJhrj+qaq+meQHwOeSHAhcC+wzw8+RJE1D16eB44BHAc+sqvM3FCZ5Os3GRccBMxr5VFXXAH8wSfmNwB4zubckadN1bXraCzisN0kAVNUPgCOAlww6MEnSeOiaKB5IM2N6MrcBDxhMOJKkcdM1UXwPOKztZP6t9vywtl6SNAd17aN4B3AOcF2Sb3PPVqgvoplDsftQopMkjVynJ4qqugjYAVgOLABeQJMojgd2qKofDitASdJodZ4DUVU/wwX5JGneGcfVYyVJY2TKJ4okZ0/jPlVVznWQpDmoX9PTjR3evxD4Q5pVXyVJ
c9CUiaKq9p6qLskimmGxLwV+Bnxk8KFJksbBtBb0S/JEmpnYr6NZxfUI4JNVdfsQYpMkjYGuiwLuDLwb2Bu4DjgEWFFVvxlibJKkMdB31FOSpyX5AnAx8FTgTTTzJo43SUjS/NBv1NM3aPaE+BGwb1V9ftaikiSNjX5NTy9qf24HfDzJx/vdqKq2GVhUkqSx0S9RvG/WopAkja1+w2NNFJIkl/CQJPVnopAk9WWikCT1ZaKQJPU1FokiyWOTnJPk0iQ/TnJIW35kkhuSXNQee406Vkmab6a11tMQ3Qm8o6ouSPJQYFWSM9u6j1TVMSOMTZLmtbFIFFW1BljTvr4tyWXAY0YblSQJxqTpqVeSxcBTgO+3RQcnuTjJiiSPmOI9y5KsTLJy/fr1sxWqJM0LY5UokjwEOB04tKpuBT4BPAFYQvPEcexk76uq5VW1tKqWLliwYLbClaR5YWwSRZL70ySJU6rqCwBVtbaq7qqqu4ETgF1HGaMkzUdjkSiSBDgRuKyqjuspX9hz2SuBS2Y7Nkma78aiMxt4NrA/8KMkF7Vl7wL2S7KEZk/u1cCbRxGcJM1nY5Eoquo8IJNUfX22Y5Ek3dtYND1JksaXiUKS1JeJQpLUl4lCktSXiUKS1JeJQpLUl4lCktSXiUKS1JeJQpLUl4lCktSXiUKS1JeJQpLUl4lCktSXiUKS1JeJQpLUl4lCktSXiUKS1JeJQpLUl4lCktSXiUKS1JeJQpLU12aRKJLsmeSKJFclOXzU8UjSfDL2iSLJFsDHgRcDOwH7JdlptFFJ0vwx9okC2BW4qqquqarfAKcBLx9xTJI0b6SqRh1DX0leA+xZVW9qz/cHnlFVB/dcswxY1p7uCFwx64H2tzXws1EHMUB+n/E3177TXPs+MH7f6XFVtWCyivvNdiTDUFXLgeWjjmMqSVZW1dJRxzEofp/xN9e+01z7PrB5fafNoenpBuCxPefbtWWSpFmwOSSKHwA7JNk+yQOAfYEvjzgmSZo3xr7pqaruTHIw8C1gC2BFVf14xGFN19g2i20iv8/4m2vfaa59H9iMvtPYd2ZLkkZrc2h6kiSNkIlCktSXiWKI5trSI0lWJFmX5JJRxzIISR6b5Jwklyb5cZJDRh3TTCXZMsn5SX7Yfqf3jTqmQUiyRZILk3x11LEMQpLVSX6U5KIkK0cdz8bYRzEk7dIjVwIvAK6nGb21X1VdOtLAZiDJbsAvgJOrapdRxzNTSRYCC6vqgiQPBVYBr9jM/x8F2KqqfpHk/sB5wCFV9b0RhzYjSd4OLAV+t6peOup4ZirJamBpVY3ThLsp+UQxPHNu6ZGq+g5w06jjGJSqWlNVF7SvbwMuAx4z2qhmphq/aE/v3x6b9V+DSbYDXgJ8atSxzFcmiuF5DHBdz/n1bOb/CM1lSRYDTwG+P+JQZqxtprkIWAecWVWb+3f6KPBO4O4RxzFIBXw7yap2CaKxZqLQvJfkIcDpwKFVdeuo45mpqrqrqpbQrGKwa5LNtpkwyUuBdVW1atSxDNhzquqpNKtiH9Q2644tE8XwuPTIZqBtxz8dOKWqvjDqeAapqn4OnAPsOeJQZuLZwMvaNv3TgD9O8n9GG9LMVdUN7c91wBk0TdVjy0QxPC49Mubajt8Tgcuq6rhRxzMISRYkeXj7+kE0gykuH2lQM1BVR1TVdlW1mOZ36Oyqet2Iw5qRJFu1gydIshXwQmCsRxKaKIakqu4ENiw9chnwuc1w6ZF7SXIq8O/AjkmuT3LgqGOaoWcD+9P8lXpRe+w16qBmaCFwTpKLaf5YObOq5sSQ0jlkW+C8JD8Ezge+VlXfHHFMfTk8VpLUl08UkqS+TBSSpL5MFJKkvkwUkqS+TBSSpL5MFJp3krw6ydlJfp7kjiRXJjkuyX+bxj3emWT34UUpjQ8TheaVJMcCnwOuoZlD8ULgI8AewMencat3ArsPOj5pHI39ntnSoCT5E+DtwIFVtaKn6twky2mSxmYtyYOq
6vZRx6G5xScKzSf/E7hgQpIAfruQ3jcAkhzVbirzi3YG+ilJHr3h2nbdoUcB701S7bF7W/c7SQ5vN6va0Kx1QO9npfG37SZQt7YbQu3b3mdxz3VbJ/lMkhuT/CrJvyRZOuFeq5Mcm+Svk1wP3JpkryR3J9l+wrXbt+Wb9XL3mn0mCs0L7eJ/fwh0WSphG+ADNHsgHAo8Hjg7yYbfl1cCt9CsE/Ws9rigrfsY8B5gefv+M4AV7SqoGxwKvAs4HngNcDvw4Uni+CLwIuAvgdfS/L6ek+SJE677H8AfAW9rr/sW8P+AAyZc9waapce/tpHvL91bVXl4zPkDeDTNHgBvnub7tqDZR6SA3XrKfwYcOeHaJ9LsmXDAhPKTgR/03G8N8PEJ13y9/YzF7fme7fkf9VyzFbAe+GRP2er2fltOuN/fAT/hnmV60l57zKj/X3hsfodPFJpvNrq4WZIXJ/m3JLcAd9JsOgXwpI28dQ+aRHFGkvttOICzgCXt9riPpUlaE1cSnni+K80+DOf+NvCqXwJfBZ4z4dqzqurXE8pWAI/jng7357Xnn97Id5Duw85szRc3AncAi/pdlOTpNP9onwEcRdNUU8D3gC038hlb0zwx3DJF/UKaJAHNk0GviecL28+eaC3wyEnK7qWqrknyL8AbafakeCNwfm3mKxhrNEwUmheq6r+S/CtNm/97+lz6Spp/tF9bVU2bTfK4jh9zE80TyLOZfNvOddzzO7dgQt3E8zU0fSUTbct99y2f6inpU8AJSY4AXgW8Y4rrpL5setJ88lFg6cRRSPDb0Up7Ag8C/mtDkmj96ST3+g33fcI4m+aJ4mFVtXKS4zc0+6j/FJg48uhlE86/D2zTu0VmkgfTdJCft7Ev2vpCG+dpNL/rp3V8n3QvPlFo3qiqryQ5DjgxybOBLwG/AP478Baazt4TgEOTfBT4Cs1Iqcl2VLsceEmSb7b3uKKqrkhyPHBakg8DK2mSyc7Ak6rqTVV1V5KjgaOTrAf+lSZJ/F5737vbWL+V5N+AzyY5nKbp7C9pEtnRHb/vr5OcAhwEnFrN1qjS9I26N93DY7YP4NU07fa30PzFfSVwDPDotv6dNH/5/xL4v8AONM07B/fc42k0/Ra/bOt2b8tDM/z1xzR9IuuBc4HX97w3NKOS1gO3AacAb23v8/Ce6xbQjJi6mWYI7bnA0yd8l9X0GckEPL+97/NH/d/dY/M93OFOGgNJPgW8oKq69od0ve+HgX2Ax1fVZP0m0kbZ9CTNsiS70EyM+zeapqYX04xKOmyAn7EjsBPNk8r7TBKaCZ8opFnWLq2xAlhCM4nuWuCTwLE1oF/IdmjsM2iG+u5fTUe6tElMFJKkvhweK0nqy0QhSerLRCFJ6stEIUnqy0QhSerr/wMZZn87/rQnagAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Count how many times each category occurs in the samples.\n", + "count_cat_samples = Counter(cat_samples.tolist())\n", + "_ = plt.bar(count_cat_samples.keys(), count_cat_samples.values())\n", + "plt.xlabel(\"Category\", fontsize=15)\n", + "plt.ylabel(\"No. of occurences.\", fontsize=15)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "64c83662", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0.16666667 0.16666667 0. 0. 0. ]\n" + ] + } + ], + "source": [ + "# distrax evaluates the PMF directly; labels outside {0, ..., 5} get probability 0\n", + "print(cat.prob([1, 0, 6, 7, -1]))" + ] + }, + { + "cell_type": "markdown", + "id": "72097542", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### PMF and CDF Plots" + ] + }, + { + "cell_type": "markdown", + "id": "88a32f34", + "metadata": {}, + "source": [ + "Let's take 4 different outcomes that can have different probabilities, i.e.\n", + "
$ Categories = \\{ a, b, c, d \\} $
\n", + "\n", + "We know that the probabilities of all categories must sum to 1.
\n", + "Therefore,\n", + "
**p(d) is fixed by the other three probabilities, so the interactive plot below lets us change every category except d.**
\n", + "
$ 1 = p(a) + p(b) + p(c) + p(d)$
\n", + "
$ p(d) = 1 - (p(a) + p(b) + p(c)) $
" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "13505266", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b72f2fc332f54088ba8fd25674551771", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "interactive(children=(FloatSlider(value=0.1, description='p_a', max=1.0, min=0.01, step=0.05), FloatSlider(val…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def draw_plots(p_a=0.1, p_b=0.1, p_c=0.1):\n", + "    p_d = 1 - (p_a + p_b + p_c)\n", + "    if 0 > p_d:  # the sliders can request more than a total probability of 1\n", + "        print(\"p_a + p_b + p_c must not exceed 1\")\n", + "        return\n", + "\n", + "    theta = jnp.array([p_a, p_b, p_c, p_d])\n", + "    cat = distrax.Categorical(probs=theta)\n", + "    n = 100  # number of samples\n", + "    cat_samples = cat.sample(seed=key, sample_shape=n)\n", + "\n", + "    pmf = cat.prob(cat_samples)\n", + "    cdf = cat.cdf(cat_samples)\n", + "\n", + "    fig, ax = plt.subplots(1, 2, figsize=(12, 5))\n", + "\n", + "    # left panel: PMF evaluated at each sampled category\n", + "    ax[0].stem(cat_samples, pmf)\n", + "    ax[0].set_xlabel(\"x\", fontsize=15)\n", + "    ax[0].set_ylabel(\"P(x)\", fontsize=15)\n", + "    ax[0].set_ylim(0, 1)\n", + "    ax[0].spines['top'].set_color('none')\n", + "    ax[0].spines['right'].set_color('none')\n", + "\n", + "    # right panel: CDF evaluated at each sampled category\n", + "    ax[1].bar(cat_samples, cdf)\n", + "    ax[1].set_xlabel(\"x\", fontsize=15)\n", + "    ax[1].set_ylabel(\"CDF(x)\", fontsize=15)\n", + "    ax[1].spines['top'].set_color('none')\n", + "    ax[1].spines['right'].set_color('none')\n", + "\n", + "    plt.show()\n", + "\n", + "interact(draw_plots, p_a=(0.01, 1, 0.05), p_b=(0.01, 1, 0.05), p_c=(0.01, 1, 0.05))" + ] + }, + { + "cell_type": "markdown", + "id": "d7fa8cb9", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# Maximum Likelihood Estimation of Categorical 
distribution\n", + "\n", + "Assuming the observations are independent and identically distributed, the likelihood of the parameters $\\vec{\\theta}$ given the sample\n", + "equals the product of the probabilities of the individual observations $x_t$.
\n", + "Let us say we have observed some data $D = [x_1, x_2, \\dots, x_N]$,\n", + "\n", + "\\begin{equation}\n", + "L(\\vec{\\theta}|D) \\equiv p(D|\\vec{\\theta}) = \\prod_{t=1}^{N} p(x_t|\\vec{\\theta}) \\tag{eq. 2}\n", + "\\end{equation}" + ] + }, + { + "cell_type": "markdown", + "id": "818c08cd", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Our goal is to find the set of parameters $\\vec{\\theta}$ that maximizes the likelihood $L(\\vec{\\theta}|D)$.\n", + "This is given by:\n", + "\\begin{equation}\n", + "\\vec{\\theta}^{*} = \\operatorname*{argmax}_{\\vec{\\theta}}\\ L(\\vec{\\theta}|D)\n", + "\\end{equation}\n", + " So,\n", + "\\begin{align} \n", + "L(\\vec{\\theta}|D) &= \\prod_{t=1}^{N} \\prod_{c=1}^{C} \\theta_c ^{I(x_t\\ =\\ c)} \\tag{from eq. 1 and eq. 2} \\\\ \\\\\n", + " l &= \\log{L(\\vec{\\theta}|D)} \\tag{eq. 3} \\\\ \\\\\n", + "\\end{align}\n", + "Maximizing this product directly is awkward, so we simplify the problem by taking the natural logarithm of the likelihood, as in eq. 3.\n" + ] + }, + { + "cell_type": "markdown", + "id": "b0170fec", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "
\n", + "\\begin{equation}\n", + "l = \\sum_{t=1}^{N} \\sum_{c=1}^{C} I(x_t = c) \\log{(\\theta_c)} \\tag{eq. 4}\n", + "\\end{equation}\n", + "The logarithm is a monotonically increasing function, so maximizing the log-likelihood is equivalent to maximizing the likelihood itself.
\n" + ] + }, + { + "cell_type": "markdown", + "id": "9d6b1ca0", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Programmatically Solve MLE of Categorical Distribution\n", + "\n", + "We will use gradient descent (here the Adam variant from optax) to minimize the negative log-likelihood, which is equivalent to maximizing the likelihood.\n" + ] + }, + { + "cell_type": "markdown", + "id": "e76b5065", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "source": [ + "Before we run the optimizer, we need the following ingredients:\n", + "1. Loss function: the negative of equation 4, averaged over the samples.\n", + "2. Initial values of $\\vec{\\theta}$: the code stores unconstrained values and applies a softmax to turn them into valid probabilities.\n", + "3. Learning rate: scales the gradient and so controls the step size of each parameter update." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "993f2c66", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [], + "source": [ + "# Average negative log-likelihood of the categorical distribution (eq. 
4.)\n", + "def loss_function(thetas, samples):\n", + "    thetas = jax.nn.softmax(thetas)  # map unconstrained values to probabilities\n", + "    result_fit = distrax.Categorical(probs=thetas)\n", + "    return -jnp.sum(result_fit.log_prob(samples))/len(samples)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "3f8e8578", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [], + "source": [ + "class Model():\n", + "    def __init__(self, samples, thetas):\n", + "        self.samples = samples\n", + "        self.thetas = thetas  # unconstrained parameters, one per category\n", + "        self.grads = [0]*len(thetas)\n", + "\n", + "    def probs(self):\n", + "        return jax.nn.softmax(self.thetas)\n", + "\n", + "    def fit(self, optimizer, epochs):\n", + "        opt_state = optimizer.init(self.thetas)\n", + "        for i in range(epochs):\n", + "            loss, self.grads = jax.value_and_grad(loss_function)(self.thetas, self.samples)\n", + "\n", + "            updates, opt_state = optimizer.update(self.grads, opt_state, self.thetas)\n", + "            self.thetas = optax.apply_updates(self.thetas, updates)\n", + "\n", + "            if i%10 == 0:\n", + "                print(f\"Loss at epoch {i} : \",loss)\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "3d0bc1a6", + "metadata": {}, + "outputs": [], + "source": [ + "# initialize default parameters\n", + "learning_rate = 1e-4  # note: the optimizer below is created with its own rate (0.1)\n", + "init_thetas = jnp.array([0.5,0.5])  # unconstrained values; softmax maps them to probabilities" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "d1809c3e", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [], + "source": [ + "# Let's generate some samples\n", + "cat = distrax.Categorical(probs = [0.2, 0.8])\n", + "samples = cat.sample(seed = jax.random.PRNGKey(1), sample_shape = 1000)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "e8bc1c44", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loss at epoch 0 : 0.69314724\n", + "Loss at epoch 10 : 0.4723168\n", + "Loss at epoch 20 : 
0.48604885\n", + "Loss at epoch 30 : 0.46990055\n", + "Loss at epoch 40 : 0.47188267\n", + "Loss at epoch 50 : 0.4701206\n", + "Loss at epoch 60 : 0.46998137\n", + "Loss at epoch 70 : 0.4699761\n", + "Loss at epoch 80 : 0.4698781\n", + "Loss at epoch 90 : 0.4698786\n" + ] + } + ], + "source": [ + "optimizer = optax.adam(learning_rate = 0.1)\n", + "\n", + "model = Model(samples, init_thetas)\n", + "model.fit(optimizer=optimizer, epochs = 100)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "c2b3f82f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initial : [0.5 0.5]\n", + "Learned : [0.18020472 0.8197953 ]\n" + ] + } + ], + "source": [ + "print(\"Initial : \", init_thetas)\n", + "print(\"Learned : \", model.probs())\n" + ] + }, + { + "cell_type": "markdown", + "id": "8f70d336", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# Theoretical Derivation of MLE for Categorical distribution\n", + "Before we can differentiate the log-likelihood to find the maximum,
\n", + "we need to introduce the constraint that all probabilities $\\theta_c$ sum to 1, that is:\n", + "\\begin{equation}\n", + "\\sum_{c=1}^{C} \\theta_c = 1 \\tag{eq. 5}\n", + "\\end{equation}\n", + "\n", + "We then form the Lagrangian, which adds the constraint to the log-likelihood and has the following form:\n", + "\\begin{equation}\n", + "l(\\theta, \\lambda) = \\log{(L(\\theta))} + \\lambda(1 - \\sum_{c=1}^{C} \\theta_c) \\tag{eq. 6}\n", + "\\end{equation}\n" + ] + }, + { + "cell_type": "markdown", + "id": "8b717f2c", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "To find the maximum, we differentiate the Lagrangian with respect to each $\\theta_c$ and set the derivative to zero:\n", + "\\begin{align}\n", + "\\frac{\\partial l}{\\partial \\theta_c} &= \\frac{\\sum_{t=1}^{N}I(x_t = c)}{\\theta_c} - \\lambda = 0, \\tag{from eq. 4 and eq. 6} \\\\\n", + " \\lambda &= \\frac{\\sum_{t=1}^{N}I(x_t = c)}{\\theta_c}, \\\\\n", + "\\therefore \\theta_c &= \\frac{N_c}{\\lambda} \\tag{eq. 7}\n", + "\\end{align}\n", + "where $N_c = \\sum_{t=1}^{N} I(x_t = c)$ is the number of observations that belong to category $c$." + ] + }, + { + "cell_type": "markdown", + "id": "a06a7a01", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "To solve for $\\lambda$, we sum both sides over $c$ and use the constraint from eq. 5:\n", + "\\begin{align}\n", + " \\theta_c &= \\frac{N_c}{\\lambda}, \\\\\n", + " \\sum_{c=1}^{C}\\theta_c &= \\frac{1}{\\lambda} \\sum_{c=1}^{C}N_c, \\\\\n", + " 1 &= \\frac{1}{\\lambda} n \\tag{$n = \\sum_{c} N_c$, the total number of observations} \\\\\n", + " \\therefore \\lambda &= n\n", + "\\end{align}" + ] + }, + { + "cell_type": "markdown", + "id": "127eb699", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Let us substitute $\\lambda$ into eq. 7. 
That gives us the MLE for $\\theta_c$:\n", + "\\begin{align}\n", + " \\theta_c &= \\frac{N_c}{n}\n", + "\\end{align}" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "3f04dac8", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Theta's values \t\t: [0.3, 0.3, 0.4]\n", + "Theta's values by MLE\t: [0.327, 0.28, 0.393]\n" + ] + } + ], + "source": [ + "# Let's compute the closed-form MLE (N_c / n) for a three-category distribution\n", + "key = jax.random.PRNGKey(1)\n", + "\n", + "n2 = 1000\n", + "true_thetas = [0.3,0.3,0.4]\n", + "\n", + "cat2 = distrax.Categorical(probs=true_thetas)\n", + "cat_samples2 = cat2.sample(seed = key, sample_shape=n2)\n", + "category = range(len(true_thetas))  # every category, including any that never occur\n", + "\n", + "mle_thetas = [round(float(jnp.sum(cat_samples2 == i)/n2),3) for i in category]\n", + "\n", + "\n", + "print(\"Theta's values \\t\\t:\",true_thetas)\n", + "print(\"Theta's values by MLE\\t:\",mle_thetas)" + ] + } + ], + "metadata": { + "celltoolbar": "Slideshow", + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
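As a cross-check on the notebook's closed-form result $\theta_c = N_c / n$, here is a minimal standalone sketch in plain Python. It deliberately avoids JAX and distrax, so it is not the notebook's own code. The helper names `categorical_mle` and `avg_neg_log_likelihood` and the ten-element sample are hypothetical, chosen only for illustration. The sketch counts over a fixed range of category indices, so a category that never occurs still gets probability 0, and it checks that the MLE reaches a lower average negative log-likelihood than the uniform starting point.

```python
from collections import Counter
import math


def categorical_mle(samples, num_categories):
    """Closed-form MLE theta_c = N_c / n for labels 0..num_categories-1."""
    n = len(samples)
    counts = Counter(samples)  # N_c for every category that occurs
    # Loop over all category indices so unseen categories get 0 instead of being dropped.
    return [counts.get(c, 0) / n for c in range(num_categories)]


def avg_neg_log_likelihood(thetas, samples):
    """Average of -log(theta_{x_t}) over the samples (the negative of eq. 4, divided by n)."""
    return -sum(math.log(thetas[x]) for x in samples) / len(samples)


samples = [0, 1, 1, 1, 1, 0, 1, 1, 1, 1]  # hypothetical data: two 0s and eight 1s
mle = categorical_mle(samples, num_categories=2)
mle_loss = avg_neg_log_likelihood(mle, samples)
uniform_loss = avg_neg_log_likelihood([0.5, 0.5], samples)

print(mle)  # [0.2, 0.8]
print(uniform_loss > mle_loss)  # True
```

Under these assumptions the uniform guess scores $\log 2 \approx 0.693$, the same value the notebook reports at epoch 0, while the closed-form estimate scores lower. No other probability vector can score lower on this sample.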