Skip to content

Commit d64c348

Browse files
committed
Add graphite diamond example
1 parent 75a6a59 commit d64c348

File tree

4 files changed

+415
-121
lines changed

4 files changed

+415
-121
lines changed

example/python_pkg/Al_learn/sort_structures.ipynb

Lines changed: 28 additions & 121 deletions
Large diffs are not rendered by default.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
C4
2+
1.0
3+
0.0000000000000000 4.2546406144450799 0.0000000000000000
4+
2.4565648800000002 0.0000000000000000 0.0000000000000000
5+
0.0000000000000000 -1.3790696309310659 -3.5028300786042923
6+
C
7+
4
8+
direct
9+
0.1666444699999990 0.0000000000000000 0.9999069600000000 C0+
10+
0.8333555300000001 0.0000000000000000 0.0000930400000000 C0+
11+
0.6666444699999990 0.5000000000000000 0.9999069600000000 C0+
12+
0.3333555300000000 0.5000000000000000 0.0000930400000000 C0+

example/python_pkg/C_learn/DRAFFLE/learn_graphite_diamond.py renamed to example/python_pkg/C_learn/Dgraphite_diamond/learn.py

File renamed without changes.
Lines changed: 375 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,375 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import matplotlib.pyplot as plt\n",
10+
"%matplotlib inline\n",
11+
"\n",
12+
"# matplotlib.use(\"Agg\")\n",
13+
"\n",
14+
"from ase import Atoms\n",
15+
"from ase.build import bulk\n",
16+
"from ase.io import read\n",
17+
"from agox.databases import Database\n",
18+
"from agox.environments import Environment\n",
19+
"from agox.utils.graph_sorting import Analysis\n",
20+
"\n",
21+
"import numpy as np\n",
22+
"from sklearn.decomposition import PCA"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": null,
28+
"metadata": {},
29+
"outputs": [],
30+
"source": [
31+
"## Set up the plotting environment\n",
32+
"# matplotlib.rcParams.update(matplotlib.rcParamsDefault)\n",
33+
"plt.rc('text', usetex=True)\n",
34+
"plt.rc('font', family='cmr10', size=12)\n",
35+
"plt.rcParams[\"axes.formatter.use_mathtext\"] = True"
36+
]
37+
},
38+
{
39+
"cell_type": "code",
40+
"execution_count": null,
41+
"metadata": {},
42+
"outputs": [],
43+
"source": [
44+
"## Set the plotting parameters\n",
45+
"seed = 0\n",
46+
"identifier = \"\"\n",
47+
"# min_energy = -9.064090728759766"
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": null,
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
56+
"## Set the descriptors\n",
57+
"from agox.models.descriptors import SOAP\n",
58+
"local_descriptor = local_descriptor = SOAP.from_species([\"C\"], r_cut=5.0)"
59+
]
60+
},
61+
{
62+
"cell_type": "code",
63+
"execution_count": null,
64+
"metadata": {},
65+
"outputs": [],
66+
"source": [
67+
"## Set the calculators\n",
68+
"from chgnet.model import CHGNetCalculator\n",
69+
"from ase.calculators.singlepoint import SinglePointCalculator\n",
70+
"calc = CHGNetCalculator()"
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": null,
76+
"metadata": {},
77+
"outputs": [],
78+
"source": [
79+
"## Load the unrelaxed structures\n",
80+
"unrlxd_structures = read(\"DTMP\"+identifier+\"/unrlxd_structures_seed\"+str(seed)+\".traj\", index=\":\")\n",
81+
"for structure in unrlxd_structures:\n",
82+
" structure.calc = calc"
83+
]
84+
},
85+
{
86+
"cell_type": "code",
87+
"execution_count": null,
88+
"metadata": {},
89+
"outputs": [],
90+
"source": [
91+
"## Load the relaxed structures\n",
92+
"rlxd_structures = read(\"DTMP\"+identifier+\"/rlxd_structures_seed\"+str(seed)+\".traj\", index=\":\")\n",
93+
"for structure in rlxd_structures:\n",
94+
" structure.calc = calc"
95+
]
96+
},
97+
{
98+
"cell_type": "code",
99+
"execution_count": null,
100+
"metadata": {},
101+
"outputs": [],
102+
"source": [
103+
"# read energies from energies_unrlxd_seed0.txt and add to the respective structures using a SinglePointCalculator\n",
104+
"# the file has the form \"index energy\"\n",
105+
"filename = \"DTMP\"+identifier+\"/energies_unrlxd_seed\"+str(seed)+\".txt\"\n",
106+
"with open(filename) as f:\n",
107+
" for line in f:\n",
108+
" index, energy = line.split()\n",
109+
" index = int(index)\n",
110+
" energy = float(energy)\n",
111+
" unrlxd_structures[index].calc = SinglePointCalculator(unrlxd_structures[index], energy=energy * len(unrlxd_structures[index]))\n",
112+
"\n",
113+
"\n",
114+
"filename = \"DTMP\"+identifier+\"/energies_rlxd_seed\"+str(seed)+\".txt\"\n",
115+
"with open(filename) as f:\n",
116+
" for line in f:\n",
117+
" index, energy = line.split()\n",
118+
" index = int(index)\n",
119+
" energy = float(energy)\n",
120+
" rlxd_structures[index].calc = SinglePointCalculator(rlxd_structures[index], energy=energy * len(rlxd_structures[index]))"
121+
]
122+
},
123+
{
124+
"cell_type": "code",
125+
"execution_count": null,
126+
"metadata": {},
127+
"outputs": [],
128+
"source": [
129+
"diamond = bulk(\"C\", \"diamond\", a=3.567) # Lattice constant for diamond cubic carbon\n",
130+
"diamond.calc = calc\n",
131+
"diamond_energy = diamond.get_potential_energy()\n",
132+
"diamond_energy_per_atom = diamond_energy / len(diamond)\n",
133+
"\n",
134+
"graphite = read(\"graphite.vasp\")\n",
135+
"graphite.calc = calc\n",
136+
"graphite_energy = graphite.get_potential_energy()\n",
137+
"graphite_energy_per_atom = graphite_energy / len(graphite)"
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": null,
143+
"metadata": {},
144+
"outputs": [],
145+
"source": [
146+
"# Calculate energies per atom for the relaxed structures\n",
147+
"energies_per_atom = [structure.get_potential_energy() / len(structure) for structure in rlxd_structures]\n",
148+
"min_energy = np.min(energies_per_atom)\n",
149+
"rlxd_delta_en_per_atom = np.array(energies_per_atom) - min_energy\n",
150+
"print(\"Relaxed min energy: \", np.min(energies_per_atom))"
151+
]
152+
},
153+
{
154+
"cell_type": "code",
155+
"execution_count": null,
156+
"metadata": {},
157+
"outputs": [],
158+
"source": [
159+
"# Calculate energies per atom for the unrelaxed structures\n",
160+
"energies_per_atom = [structure.get_potential_energy() / len(structure) for structure in unrlxd_structures]\n",
161+
"unrlxd_delta_en_per_atom = np.array(energies_per_atom) - min_energy\n",
162+
"print(\"Unrelaxed min energy: \", np.min(energies_per_atom))"
163+
]
164+
},
165+
{
166+
"cell_type": "code",
167+
"execution_count": null,
168+
"metadata": {},
169+
"outputs": [],
170+
"source": [
171+
"if abs( np.min(energies_per_atom) - min_energy ) > 5e-2:\n",
172+
" print(\"Minimum energy per atom is not zero. Check the energy calculation.\")"
173+
]
174+
},
175+
{
176+
"cell_type": "code",
177+
"execution_count": null,
178+
"metadata": {},
179+
"outputs": [],
180+
"source": [
181+
"## Set up the PCA\n",
182+
"pca = PCA(n_components=2)"
183+
]
184+
},
185+
{
186+
"cell_type": "code",
187+
"execution_count": null,
188+
"metadata": {},
189+
"outputs": [],
190+
"source": [
191+
"## Fit the PCA model to the unrelaxed or relaxed structures\n",
192+
"rlxd_string = \"rlxd\""
193+
]
194+
},
195+
{
196+
"cell_type": "code",
197+
"execution_count": null,
198+
"metadata": {},
199+
"outputs": [],
200+
"source": [
201+
"## Get the 'super atom' descriptors for the unrelaxed structures\n",
202+
"unrlxd_super_atoms = []\n",
203+
"for structure in unrlxd_structures:\n",
204+
" unrlxd_super_atoms.append( np.mean(local_descriptor.get_features(structure), axis=0) )"
205+
]
206+
},
207+
{
208+
"cell_type": "code",
209+
"execution_count": null,
210+
"metadata": {},
211+
"outputs": [],
212+
"source": [
213+
"## Get the 'super atom' descriptors for the relaxed structures\n",
214+
"rlxd_super_atoms = []\n",
215+
"for structure in rlxd_structures:\n",
216+
" rlxd_super_atoms.append( np.mean(local_descriptor.get_features(structure), axis=0) )"
217+
]
218+
},
219+
{
220+
"cell_type": "code",
221+
"execution_count": null,
222+
"metadata": {},
223+
"outputs": [],
224+
"source": [
225+
"## Save pca model\n",
226+
"import pickle\n",
227+
"if True:\n",
228+
" pca.fit(np.squeeze([arr for arr in rlxd_super_atoms]))\n",
229+
" with open(\"pca_model_all_rlxd_\"+str(seed)+\".pkl\", \"wb\") as f:\n",
230+
" pickle.dump(pca, f)\n",
231+
"\n",
232+
"## Load pca model\n",
233+
"with open(\"pca_model_all_\"+rlxd_string+\"_0.pkl\", \"rb\") as f:\n",
234+
" pca = pickle.load(f)"
235+
]
236+
},
237+
{
238+
"cell_type": "code",
239+
"execution_count": null,
240+
"metadata": {},
241+
"outputs": [],
242+
"source": [
243+
"# Get super atom descriptors for diamond and graphite\n",
244+
"graphite_super_atoms = [ np.mean(local_descriptor.get_features(graphite), axis=0) ]\n",
245+
"diamond_super_atoms = [ np.mean(local_descriptor.get_features(diamond), axis=0) ]"
246+
]
247+
},
248+
{
249+
"cell_type": "code",
250+
"execution_count": null,
251+
"metadata": {},
252+
"outputs": [],
253+
"source": [
254+
"## Transform the unrelaxed and relaxed structures to the reduced space\n",
255+
"unrlxd_X_reduced = pca.transform(np.squeeze([arr for arr in unrlxd_super_atoms]))\n",
256+
"rlxd_X_reduced = pca.transform(np.squeeze([arr for arr in rlxd_super_atoms]))\n",
257+
"graphite_X_reduced = pca.transform([np.squeeze([graphite_super_atoms])])\n",
258+
"diamond_X_reduced = pca.transform([np.squeeze([diamond_super_atoms])])"
259+
]
260+
},
261+
{
262+
"cell_type": "code",
263+
"execution_count": null,
264+
"metadata": {},
265+
"outputs": [],
266+
"source": [
267+
"## Get the index of the structure with the minimum energy\n",
268+
"min_energy_index = np.argmin(rlxd_delta_en_per_atom)\n",
269+
"print(min_energy_index)"
270+
]
271+
},
272+
{
273+
"cell_type": "code",
274+
"execution_count": null,
275+
"metadata": {},
276+
"outputs": [],
277+
"source": [
278+
"## Plot the PCA\n",
279+
"fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 6))\n",
280+
"\n",
281+
"plt.subplots_adjust(wspace=0.05, hspace=0)\n",
282+
"\n",
283+
"## Get the maximum energy for the colourbar\n",
284+
"max_en = min(3.5, max(np.max(unrlxd_delta_en_per_atom), np.max(rlxd_delta_en_per_atom)))\n",
285+
"\n",
286+
"## Plot the PCA\n",
287+
"axes[0].scatter(unrlxd_X_reduced[:, 0], unrlxd_X_reduced[:, 1], c=unrlxd_delta_en_per_atom, cmap=\"viridis\", vmin = 0, vmax = max_en)\n",
288+
"axes[1].scatter(rlxd_X_reduced[:, 0], rlxd_X_reduced[:, 1], c=rlxd_delta_en_per_atom, cmap=\"viridis\", vmin = 0, vmax = max_en)\n",
289+
"\n",
290+
"## Add the minimum energy structures to the plot\n",
291+
"for ax in axes:\n",
292+
" ax.scatter(diamond_X_reduced[0,0], diamond_X_reduced[0,1], s=200, edgecolor=[1.0, 0.0, 0.0, 0.5], facecolor='none', linewidth=2, label='diamond')\n",
293+
" ax.scatter(graphite_X_reduced[0,0], graphite_X_reduced[0,1], s=200, edgecolor=[1.0, 0.0, 0.0, 1.0], facecolor='none', linewidth=2, label='graphite')\n",
294+
" ax.legend(fontsize=10)\n",
295+
" handles, labels = ax.get_legend_handles_labels()\n",
296+
" ax.legend(handles[::-1], labels[::-1], facecolor='white', framealpha=1.0, edgecolor='black', fancybox=False, loc='lower right')\n",
297+
"\n",
298+
"## Add labels\n",
299+
"fig.text(0.5, 0.04, 'Principal Component 1', ha='center', fontsize=15)\n",
300+
"axes[0].set_ylabel('Principal Component 2', fontsize=15)\n",
301+
"axes[0].set_title('Unrelaxed')\n",
302+
"axes[1].set_title('Relaxed')\n",
303+
"if identifier == \"_VASP\":\n",
304+
" if rlxd_string == \"rlxd\":\n",
305+
" xlims = [-11, 8]\n",
306+
" ylims = [-5, 6]\n",
307+
" else:\n",
308+
" xlims = [-9, 13]\n",
309+
" ylims = [-7, 12]\n",
310+
"else:\n",
311+
" if rlxd_string == \"rlxd\":\n",
312+
" xlims = [-310, 310]\n",
313+
" ylims = [-53, 53]\n",
314+
" else:\n",
315+
" xlims = [-5, 13]\n",
316+
" ylims = [-6.5, 13]\n",
317+
"\n",
318+
"for ax in axes:\n",
319+
" ax.tick_params(axis='both', direction='in')\n",
320+
" ax.set_xlim(xlims)\n",
321+
" ax.set_ylim(ylims)\n",
322+
"\n",
323+
"## Unify tick labels\n",
324+
"xticks = axes[0].get_xticks()\n",
325+
"xticks = xticks[(xticks >= xlims[0]) & (xticks <= xlims[1])]\n",
326+
"\n",
327+
"axes[1].set_xticks(xticks)\n",
328+
"axes[1].set_yticklabels([])\n",
329+
"axes[0].tick_params(axis='x', labelbottom=True, top=True)\n",
330+
"axes[1].tick_params(axis='x', labelbottom=True, top=True)\n",
331+
"axes[0].tick_params(axis='y', labelbottom=True, right=True)\n",
332+
"axes[1].tick_params(axis='y', labelbottom=True, right=True)\n",
333+
"\n",
334+
"## Make axes[0] and axes[1] the same width\n",
335+
"axes[0].set_box_aspect(1.7)\n",
336+
"axes[1].set_box_aspect(1.7)\n",
337+
"\n",
338+
"## Add colorbar next to the axes\n",
339+
"cbar = fig.colorbar(axes[1].collections[0], ax=axes, orientation='vertical', fraction=0.085, pad=0.02)\n",
340+
"cbar.set_label('Formation energy (eV/atom)', fontsize=15)\n",
341+
"\n",
342+
"## Save the figure\n",
343+
"plt.savefig('C_RAFFLE'+identifier+'_pca_'+rlxd_string+'_fit_seed'+str(seed)+'.pdf', bbox_inches='tight', pad_inches=0, facecolor=fig.get_facecolor(), edgecolor='none')"
344+
]
345+
},
346+
{
347+
"cell_type": "code",
348+
"execution_count": null,
349+
"metadata": {},
350+
"outputs": [],
351+
"source": []
352+
}
353+
],
354+
"metadata": {
355+
"kernelspec": {
356+
"display_name": "raffle_env",
357+
"language": "python",
358+
"name": "python3"
359+
},
360+
"language_info": {
361+
"codemirror_mode": {
362+
"name": "ipython",
363+
"version": 3
364+
},
365+
"file_extension": ".py",
366+
"mimetype": "text/x-python",
367+
"name": "python",
368+
"nbconvert_exporter": "python",
369+
"pygments_lexer": "ipython3",
370+
"version": "3.12.8"
371+
}
372+
},
373+
"nbformat": 4,
374+
"nbformat_minor": 2
375+
}

0 commit comments

Comments
 (0)