{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"import os\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"plt.rcdefaults()\n",
"from matplotlib.lines import Line2D\n",
"from matplotlib.patches import Rectangle\n",
"from matplotlib.patches import Circle\n",
"\n",
"# Drawing limits: a layer is drawn with at most NumConvMax (conv) or NumFcMax\n",
"# (fully connected) slots; in omission mode NumDots of the middle slots are\n",
"# drawn as dots when the real count exceeds that limit.\n",
"NumDots, NumConvMax, NumFcMax = 4, 8, 20\n",
"\n",
"# Grey levels in [0, 1]: a patch is coloured level * (1, 1, 1), so 0. is\n",
"# black and 1. is white.\n",
"White, Light, Medium, Dark, Darker, Black = 1., 0.7, 0.5, 0.3, 0.15, 0.\n",
"\n",
"\n",
"def add_layer(patches, colors, size=(24, 24), num=5,\n",
"              top_left=[0, 0], loc_diff=[3, -3]):\n",
"    \"\"\"Append one layer, drawn as a stack of `num` rectangles, to `patches`.\n",
"\n",
"    `size` is (height, width). Rectangle k has its lower-left corner at\n",
"    top_left - (0, height) + k * loc_diff, so `top_left` is the top-left\n",
"    corner of the first rectangle. Fill levels alternate Light, Medium,\n",
"    Light, ... and are appended to `colors` in step with `patches`.\n",
"    \"\"\"\n",
"    height, width = size[0], size[1]\n",
"    origin = np.array(top_left) - np.array([0, height])\n",
"    step = np.array(loc_diff)\n",
"    for k in range(num):\n",
"        patches.append(Rectangle(origin + k * step, width, height))\n",
"        colors.append(Medium if k % 2 else Light)\n",
"\n",
"\n",
"def add_layer_with_omission(patches, colors, size=(24, 24),\n",
"                            num=5, num_max=8,\n",
"                            num_dots=4,\n",
"                            top_left=[0, 0],\n",
"                            loc_diff=[3, -3],\n",
"                            ):\n",
"    \"\"\"Append a layer like add_layer, capped at `num_max` drawn slots.\n",
"\n",
"    min(num, num_max) slots are laid out as in add_layer. When\n",
"    num > num_max, `num_dots` consecutive slots near the middle are drawn\n",
"    as small black dots (Circle of radius 0.5) at the slot centre instead\n",
"    of rectangles, to show that part of the stack was left out.\n",
"\n",
"    `size` is (height, width); one colour is appended to `colors` for\n",
"    every patch appended to `patches`.\n",
"    \"\"\"\n",
"    top_left = np.array(top_left)\n",
"    loc_diff = np.array(loc_diff)\n",
"    loc_start = top_left - np.array([0, size[0]])\n",
"    # Offset from a slot's lower-left corner to its centre, in (x, y) order.\n",
"    # `size` is (height, width), so it must be reversed here; using\n",
"    # np.array(size) / 2 would swap the axes for non-square layers.\n",
"    centre_offset = np.array([size[1], size[0]]) / 2\n",
"    this_num = min(num, num_max)\n",
"    # Slots in the half-open window [omit_begin, omit_end) become dots.\n",
"    # Ending at omit_begin + num_dots yields exactly num_dots dots, even\n",
"    # when this_num - num_dots is odd.\n",
"    omit_begin = (this_num - num_dots) // 2\n",
"    omit_end = omit_begin + num_dots\n",
"    truncated = num > num_max\n",
"    for ind in range(this_num):\n",
"        corner = loc_start + ind * loc_diff\n",
"        if truncated and omit_begin <= ind < omit_end:\n",
"            patches.append(Circle(corner + centre_offset, 0.5))\n",
"            colors.append(Black)\n",
"        else:\n",
"            patches.append(Rectangle(corner, size[1], size[0]))\n",
"            colors.append(Medium if ind % 2 else Light)\n",
"\n",
"\n",
"def add_mapping(patches, colors, start_ratio, end_ratio, patch_size, ind_bgn,\n",
"                top_left_list, loc_diff_list, num_show_list, size_list):\n",
"    \"\"\"Draw the kernel window linking layer ind_bgn to layer ind_bgn + 1.\n",
"\n",
"    A Dark rectangle of `patch_size` (height, width) is placed on the last\n",
"    drawn slice of layer `ind_bgn`, offset by the fractions in\n",
"    `start_ratio` (x, y). Four Darker lines then join its corners to a\n",
"    single point on the last drawn slice of layer `ind_bgn + 1`, offset by\n",
"    the fractions in `end_ratio`. Sizes in `size_list` are (height, width).\n",
"    \"\"\"\n",
"    def last_slice_top_left(layer):\n",
"        # Top-left corner of the final slice drawn for `layer`.\n",
"        return (top_left_list[layer]\n",
"                + (num_show_list[layer] - 1) * np.array(loc_diff_list[layer]))\n",
"\n",
"    src_h, src_w = size_list[ind_bgn][0], size_list[ind_bgn][1]\n",
"    dst_h, dst_w = size_list[ind_bgn + 1][0], size_list[ind_bgn + 1][1]\n",
"    patch_h, patch_w = patch_size[0], patch_size[1]\n",
"\n",
"    start_loc = last_slice_top_left(ind_bgn) + np.array(\n",
"        [start_ratio[0] * (src_w - patch_w),\n",
"         - start_ratio[1] * (src_h - patch_h)])\n",
"    end_loc = last_slice_top_left(ind_bgn + 1) + np.array(\n",
"        [end_ratio[0] * dst_w, - end_ratio[1] * dst_h])\n",
"\n",
"    # Negative height: the window hangs down from its top-left corner.\n",
"    patches.append(Rectangle(start_loc, patch_w, -patch_h))\n",
"    colors.append(Dark)\n",
"    x0, y0 = start_loc[0], start_loc[1]\n",
"    # Corners in the order top-left, top-right, bottom-left, bottom-right.\n",
"    corners = [(x0, y0), (x0 + patch_w, y0),\n",
"               (x0, y0 - patch_h), (x0 + patch_w, y0 - patch_h)]\n",
"    for cx, cy in corners:\n",
"        patches.append(Line2D([cx, end_loc[0]], [cy, end_loc[1]]))\n",
"        colors.append(Darker)\n",
"\n",
"\n",
"\n",
"def label(xy, text, xy_off=[0, 4]):\n",
"    \"\"\"Place `text` (8 pt, sans-serif) at point `xy` shifted by `xy_off`.\"\"\"\n",
"    x = xy[0] + xy_off[0]\n",
"    y = xy[1] + xy_off[1]\n",
"    plt.text(x, y, text, family='sans-serif', size=8)\n",
"\n",
"\n",
"if __name__ == '__main__':\n",
"\n",
"    fc_unit_size = 2\n",
"    layer_width = 40\n",
"    flag_omit = True\n",
"    # Set save_fig = True to also write the diagram into fig_dir.\n",
"    save_fig = False\n",
"    fig_dir = './'\n",
"    fig_ext = '.png'\n",
"\n",
"    patches = []\n",
"    colors = []\n",
"\n",
"    fig, ax = plt.subplots()\n",
"\n",
"    ############################\n",
"    # conv layers\n",
"    size_list = [(32, 32), (18, 18), (10, 10), (6, 6), (4, 4)]\n",
"    num_list = [3, 32, 32, 48, 48]\n",
"    x_diff_list = [0, layer_width, layer_width, layer_width, layer_width]\n",
"    text_list = ['Inputs'] + ['Feature\\nmaps'] * (len(size_list) - 1)\n",
"    loc_diff_list = [[3, -3]] * len(size_list)\n",
"\n",
"    num_show_list = list(map(min, num_list, [NumConvMax] * len(num_list)))\n",
"    top_left_list = np.c_[np.cumsum(x_diff_list), np.zeros(len(x_diff_list))]\n",
"\n",
"    # Iterate last-to-first: patches are added to the axes in list order, so\n",
"    # earlier layers end up drawn over later ones where they overlap.\n",
"    for ind in range(len(size_list) - 1, -1, -1):\n",
"        if flag_omit:\n",
"            add_layer_with_omission(patches, colors, size=size_list[ind],\n",
"                                    num=num_list[ind],\n",
"                                    num_max=NumConvMax,\n",
"                                    num_dots=NumDots,\n",
"                                    top_left=top_left_list[ind],\n",
"                                    loc_diff=loc_diff_list[ind])\n",
"        else:\n",
"            add_layer(patches, colors, size=size_list[ind],\n",
"                      num=num_show_list[ind],\n",
"                      top_left=top_left_list[ind], loc_diff=loc_diff_list[ind])\n",
"        label(top_left_list[ind], text_list[ind] + '\\n{}@{}x{}'.format(\n",
"            num_list[ind], size_list[ind][0], size_list[ind][1]))\n",
"\n",
"    ############################\n",
"    # in between layers\n",
"    start_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]\n",
"    end_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]\n",
"    patch_size_list = [(5, 5), (2, 2), (5, 5), (2, 2)]\n",
"    ind_bgn_list = range(len(patch_size_list))\n",
"    text_list = ['Convolution', 'Max-pooling', 'Convolution', 'Max-pooling']\n",
"\n",
"    for ind in ind_bgn_list:\n",
"        add_mapping(\n",
"            patches, colors, start_ratio_list[ind], end_ratio_list[ind],\n",
"            patch_size_list[ind], ind,\n",
"            top_left_list, loc_diff_list, num_show_list, size_list)\n",
"        label(top_left_list[ind], text_list[ind] + '\\n{}x{} kernel'.format(\n",
"            patch_size_list[ind][0], patch_size_list[ind][1]), xy_off=[26, -65]\n",
"        )\n",
"\n",
"    ############################\n",
"    # fully connected layers\n",
"    size_list = [(fc_unit_size, fc_unit_size)] * 3\n",
"    num_list = [768, 500, 2]\n",
"    num_show_list = list(map(min, num_list, [NumFcMax] * len(num_list)))\n",
"    x_diff_list = [sum(x_diff_list) + layer_width, layer_width, layer_width]\n",
"    top_left_list = np.c_[np.cumsum(x_diff_list), np.zeros(len(x_diff_list))]\n",
"    loc_diff_list = [[fc_unit_size, -fc_unit_size]] * len(top_left_list)\n",
"    text_list = ['Hidden\\nunits'] * (len(size_list) - 1) + ['Outputs']\n",
"\n",
"    for ind in range(len(size_list)):\n",
"        if flag_omit:\n",
"            add_layer_with_omission(patches, colors, size=size_list[ind],\n",
"                                    num=num_list[ind],\n",
"                                    num_max=NumFcMax,\n",
"                                    num_dots=NumDots,\n",
"                                    top_left=top_left_list[ind],\n",
"                                    loc_diff=loc_diff_list[ind])\n",
"        else:\n",
"            add_layer(patches, colors, size=size_list[ind],\n",
"                      num=num_show_list[ind],\n",
"                      top_left=top_left_list[ind],\n",
"                      loc_diff=loc_diff_list[ind])\n",
"        label(top_left_list[ind], text_list[ind] + '\\n{}'.format(\n",
"            num_list[ind]))\n",
"\n",
"    text_list = ['Flatten\\n', 'Fully\\nconnected', 'Fully\\nconnected']\n",
"\n",
"    for ind in range(len(size_list)):\n",
"        label(top_left_list[ind], text_list[ind], xy_off=[-10, -65])\n",
"\n",
"    ############################\n",
"    for patch, color in zip(patches, colors):\n",
"        patch.set_color(color * np.ones(3))\n",
"        if isinstance(patch, Line2D):\n",
"            ax.add_line(patch)\n",
"        else:\n",
"            patch.set_edgecolor(Black * np.ones(3))\n",
"            ax.add_patch(patch)\n",
"\n",
"    # Size, lay out and (optionally) save the figure before plt.show():\n",
"    # the figure is rendered when shown, and Jupyter's inline backend closes\n",
"    # it afterwards, so resizing or saving after show() has no effect.\n",
"    fig.set_size_inches(8, 2.5)\n",
"    ax.axis('equal')\n",
"    ax.axis('off')\n",
"    fig.tight_layout()\n",
"    if save_fig:\n",
"        fig.savefig(os.path.join(fig_dir, 'convnet_fig' + fig_ext),\n",
"                    bbox_inches='tight', pad_inches=0)\n",
"    plt.show()\n",
"\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}