class Lattice:
    """Lattice spanned by two 2-D or three 3-D vectors.

    Attributes:
        dim: number of lattice vectors, which is also the required length of
            each vector (2 or 3).
        vecs: float array of shape ``(dim, dim)``; row ``i`` is lattice vector ``i``.
        vec_lengths: float array of shape ``(dim,)`` with the Euclidean norm of
            each lattice vector.
        center: float array of shape ``(dim,)`` filled with zeros; every
            lattice point is computed as an offset from it.
    """

    def __init__(self, *vecs):
        """Build the lattice from its spanning vectors.

        Args:
            *vecs: either 2 or 3 vectors as separate arguments, or one single
                iterable holding them (e.g. ``Lattice([[1, 0], [0, 1]])``).
                Each vector must have as many components as there are vectors.

        Raises:
            ValueError: if the number of vectors is not 2 or 3, or if a vector
                does not have shape ``(dim,)``.
        """
        # A single positional argument is taken to be an iterable of vectors.
        if len(vecs) == 1:
            vecs = vecs[0]
        if len(vecs) not in (2, 3):
            raise ValueError("Vecs must contain either 2 or 3 vectors")
        self.dim = len(vecs)

        rows = []
        for i, v in enumerate(vecs):
            # Always coerce to float so that `vecs` has the same dtype as
            # `center`, no matter whether the caller wrote ints or floats.
            row = np.asarray(v, dtype=float)
            if row.shape != (self.dim,):
                raise ValueError(f"Got {self.dim} vectors, therefore all vectors must be {self.dim} dimensional but vector {i+1} has shape {row.shape}")
            rows.append(row)

        self.vecs = np.array(rows)
        self.vec_lengths = np.linalg.norm(self.vecs, axis=1)
        self.center = np.zeros(self.dim)

    def get_point(self, *ns):
        """Return the lattice point ``center + sum_i ns[i] * vecs[i]``.

        Args:
            *ns: one coefficient per lattice vector.

        Returns:
            A new float array of shape ``(dim,)``.

        Raises:
            ValueError: if the number of coefficients differs from the number
                of lattice vectors.
        """
        if len(ns) != len(self.vecs):
            raise ValueError(f"Requires one index for each lattice vector {len(self.vecs)}, but got only {ns}")
        point = self.center.copy()
        for coeff, vec in zip(ns, self.vecs):
            point += coeff * vec
        return point

    def get_points_around_center(self, n):
        """Return all lattice points whose coefficients lie in ``-n .. n``.

        Args:
            n: largest absolute coefficient used for every lattice vector.

        Returns:
            A list of ``(2 * n + 1) ** dim`` arrays. The order is that of
            ``itertools.product``, i.e. the coefficient of the last lattice
            vector varies fastest.
        """
        # Kept as a local import so the class does not depend on the
        # notebook's import cell.
        import itertools

        coefficients = range(-n, n + 1)
        return [
            self.get_point(*index)
            for index in itertools.product(coefficients, repeat=self.dim)
        ]
def rot_mat_2D(rad):
    """Return the 2x2 counter-clockwise rotation matrix for the angle `rad` (radians)."""
    c, s = np.cos(rad), np.sin(rad)
    return np.array([[c, -s], [s, c]])


def get_reciprocal_lattice(lattice: Lattice):
    """Return the reciprocal lattice, whose vectors satisfy ``a_i . b_j = 2*pi*delta_ij``.

    Args:
        lattice: a 2- or 3-dimensional lattice; ``lattice.vecs`` holds one
            lattice vector per row.

    Returns:
        A new ``Lattice`` built from the reciprocal vectors.

    Raises:
        ValueError: if the lattice vectors are linearly dependent (zero cell
            area / volume), in which case no reciprocal lattice exists.
        NotImplementedError: if ``lattice.dim`` is neither 2 nor 3.
    """
    if lattice.dim == 2:
        a1, a2 = lattice.vecs
        # Signed cell area. The quarter turns below are written out exactly
        # instead of going through rot_mat_2D(pi / 2), whose cosine is ~6e-17
        # rather than 0 and leaked 1e-16 noise into the result.
        volume = a1[0] * a2[1] - a1[1] * a2[0]
        unscaled = [
            np.array([a2[1], -a2[0]]),
            np.array([-a1[1], a1[0]]),
        ]
    elif lattice.dim == 3:
        a1, a2, a3 = lattice.vecs
        volume = np.dot(a1, np.cross(a2, a3))
        unscaled = [np.cross(a2, a3), np.cross(a3, a1), np.cross(a1, a2)]
    else:
        raise NotImplementedError(f"Dim must be 2 or 3, but is {lattice.dim}")

    # Compare against the product of the vector lengths so that the check is
    # independent of the length unit used for the lattice.
    scale = np.prod(np.linalg.norm(lattice.vecs, axis=1))
    if abs(volume) <= np.finfo(float).eps * scale:
        raise ValueError("Lattice vectors are linearly dependent, the reciprocal lattice is undefined")

    # Adding 0.0 turns any -0.0 component into +0.0 so printed vectors are clean.
    reciprocal_vecs = 2 * np.pi / volume * np.array(unscaled) + 0.0
    return Lattice(*reciprocal_vecs)


def get_elementary_cell(lattice: Lattice):
    """Return the lattice center together with all points one step around it.

    These points are the input of the Voronoi construction in `plot_unit_cell`.
    NOTE(review): one shell of neighbours only yields the true unit cell when
    the nearest lattice points are reachable with coefficients in -1..1, which
    may not hold for a strongly skewed basis.
    """
    return lattice.get_points_around_center(1)


def get_orthogonal_2D(vec):
    """Return `vec` rotated clockwise by 90 degrees: ``(x, y) -> (y, -x)``."""
    return np.array((vec[1], -vec[0]))


def _center_point_index(lattice: Lattice, voronoi: Voronoi):
    """Return the index of the Voronoi input point closest to ``lattice.center``."""
    distances = np.linalg.norm(voronoi.points - lattice.center, axis=1)
    return int(np.argmin(distances))


def get_unit_cell_vertices(lattice: Lattice, voronoi: Voronoi):
    """Return the vertices of the Voronoi cell around the lattice center.

    The cell is looked up through the Voronoi region of the input point that
    sits on ``lattice.center``. The previous distance threshold
    (``0.5 * sqrt(sum of squared vector lengths)``) could drop real cell
    vertices or keep vertices of neighbouring cells for skewed bases, and
    compared floats with ``<=`` right on the boundary for square lattices.

    Args:
        lattice: lattice whose ``center`` identifies the cell of interest.
        voronoi: Voronoi diagram of lattice points that include the center.

    Returns:
        Array of shape ``(n_vertices, dim)`` in the order in which the region
        lists them; vertices at infinity are left out.
    """
    region = voronoi.regions[voronoi.point_region[_center_point_index(lattice, voronoi)]]
    # -1 marks a vertex outside the diagram (unbounded region).
    finite_vertices = [index for index in region if index != -1]
    return voronoi.vertices[finite_vertices]


def _plot_unit_cell_2d(ax, lattice, voronoi, vec_label):
    """Draw the filled unit cell, the lattice points and the labelled lattice vectors."""
    # ax.fill needs the polygon corners in order, so sort them by polar angle
    # around the origin (the cell is convex and contains the origin).
    cell_points = sorted(
        get_unit_cell_vertices(lattice, voronoi),
        key=lambda p: np.arctan2(p[1], p[0]),
    )
    x_cell, y_cell = zip(*cell_points)
    ax.fill(x_cell, y_cell, color="#4444")
    ax.scatter(x_cell, y_cell, color="orange")

    # lattice points
    x_lat, y_lat = zip(*lattice.get_points_around_center(3))
    ax.scatter(x_lat, y_lat, color="blue")

    # lattice vectors
    arrowprops = dict(arrowstyle="-|>,head_width=0.4,head_length=0.8", color="black", shrinkA=0, shrinkB=0)
    for i, vec in enumerate(lattice.vecs):
        ax.annotate("", xy=vec, xytext=lattice.center, arrowprops=arrowprops)
        if vec_label is not None:
            # name of the vector, placed at 70 % of its length with a small
            # perpendicular offset
            ax.annotate(
                r"$\vec{" + f"{vec_label}}}_{i+1}$",
                xy=0.7 * vec,
                xytext=0.7 * vec + 0.06 * get_orthogonal_2D(vec),
            )


def _plot_unit_cell_3d(ax, lattice, voronoi):
    """Draw the faces of the central Voronoi cell, the lattice points and the lattice vectors."""
    center_index = _center_point_index(lattice, voronoi)
    # A ridge is a face of the central cell when the center is one of the two
    # input points it separates; ridges containing -1 extend to infinity.
    # NOTE(review): this relies on Qhull listing the vertices of a 3-D ridge
    # in polygon order - check the rendered faces visually.
    faces = [
        voronoi.vertices[ridge]
        for point_pair, ridge in zip(voronoi.ridge_points, voronoi.ridge_vertices)
        if center_index in point_pair and -1 not in ridge
    ]
    if faces:
        ax.add_collection3d(Poly3DCollection(faces, facecolors="#4444", edgecolors="black"))

    # The lattice points are the same for every ridge, so draw them only once.
    x_lat, y_lat, z_lat = zip(*lattice.get_points_around_center(1))
    ax.scatter(x_lat, y_lat, z_lat, color="red", marker=".")

    # lattice vectors as line segments starting at the center
    for vec in lattice.vecs:
        ax.plot(*zip(lattice.center, vec))


def plot_unit_cell(lattice, fig_ax=None, vec_label="a", subplot_kw=None):
    """Plot the unit cell of `lattice` obtained from a Voronoi construction.

    Args:
        lattice: 2- or 3-dimensional ``Lattice``.
        fig_ax: optional ``(fig, ax)`` tuple to draw into; a new figure is
            created when omitted. For a 3-D lattice the axes must be 3-D axes.
        vec_label: base name used to label the lattice vectors in 2-D plots;
            ``None`` disables the labels.
        subplot_kw: keyword arguments forwarded to ``plt.subplots`` when a new
            2-D figure is created.

    Returns:
        The figure that was drawn into.

    Raises:
        NotImplementedError: if ``lattice.dim`` is neither 2 nor 3.
    """
    # Voronoi diagram of the center and the points around it
    voronoi = Voronoi(get_elementary_cell(lattice))

    if fig_ax:
        fig, ax = fig_ax
    elif lattice.dim == 3:
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1, projection="3d")
    else:
        fig, ax = plt.subplots(**(subplot_kw or {}))

    if lattice.dim == 2:
        _plot_unit_cell_2d(ax, lattice, voronoi, vec_label)
    elif lattice.dim == 3:
        _plot_unit_cell_3d(ax, lattice, voronoi)
    else:
        raise NotImplementedError(f"Dim must be 2 or 3, but is {lattice.dim}")

    # Limit every coordinate axis to about twice the largest component any
    # lattice vector has along it. The rows of `vecs` are vectors, so the
    # maximum has to be taken over axis 0 (per coordinate), not per vector.
    half_widths = 2.05 * np.max(np.abs(lattice.vecs), axis=0)
    ax.set_xlim(-half_widths[0], half_widths[0])
    ax.set_ylim(-half_widths[1], half_widths[1])
    if lattice.dim == 3:
        ax.set_zlim(-half_widths[2], half_widths[2])
    fig.tight_layout()
    return fig


def sphere_point(rad):
    """Return the point on the unit circle at polar angle `rad` (radians)."""
    return np.array([np.cos(rad), np.sin(rad)])


square_lattice = Lattice([1, 0], [0, 1])
tilted_lattice = Lattice([1, 0.5], [0, 1])
# two unit vectors enclosing an angle of 120 degrees
honeycomb_lattice = Lattice(sphere_point(0), sphere_point(np.pi * 2/3))

# Example: all three unit cells stacked in one figure
# fig, axs = plt.subplots(3, figsize=(4, 12))
# plot_unit_cell(square_lattice, fig_ax=(fig, axs[0]))
# plot_unit_cell(tilted_lattice, fig_ax=(fig, axs[1]))
# plot_unit_cell(honeycomb_lattice, fig_ax=(fig,axs[2]));


def plot_lattice(lattice: Lattice):
    """Plot the unit cell of `lattice` next to the unit cell of its reciprocal lattice.

    The reciprocal lattice vectors are printed as well. The left axes show the
    direct lattice, the right axes the reciprocal lattice (vectors labelled "b").
    """
    reci = get_reciprocal_lattice(lattice)
    print(reci.vecs)
    if lattice.dim == 3:
        fig = plt.figure()
        axs = [fig.add_subplot(1, 2, i, projection="3d") for i in (1, 2)]
        fig.suptitle("3D Lattice")
    else:
        fig, axs = plt.subplots(1, 2)
    plot_unit_cell(lattice, fig_ax=(fig, axs[0]))
    plot_unit_cell(reci, fig_ax=(fig, axs[1]), vec_label="b")
"metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 6.28318531e+00 -3.84734139e-16]\n", " [ 3.84734139e-16 6.28318531e+00]]\n", "[[ 6.28318531e+00 -3.84734139e-16]\n", " [-3.14159265e+00 6.28318531e+00]]\n", "[[6.28318531e+00 3.62759873e+00]\n", " [4.44252717e-16 7.25519746e+00]]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9509fdd0b81b47a7bdcdd99704c382a7", "version_major": 2, "version_minor": 0 }, "image/png": "", "text/html": [ "\n", "