Skip to content

Commit a470d9d

Browse files
zhengjiang shaozhengjiang shao
authored andcommitted
Migrate 3D visualization from mayavi to PyVista
- Replace mayavi with PyVista for plot_contour3d, plot_mcontour, plot_field - Add _build_structured_grid using lattice vectors for non-orthogonal cell support (fixes hexagonal cell ELFCAR issue #8) - Update example scripts to use PyVista - Add pyvista>=0.42.0 to dependencies - Remove outdated Canopy installation instructions - Fixes #18, fixes #8, fixes #4 Co-Authored-By: deepseek-v4-pro
1 parent a48bb43 commit a470d9d

6 files changed

Lines changed: 115 additions & 103 deletions

File tree

README.rst

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,13 @@ Installation
4646

4747
python setup.py install
4848

49-
If you want to use **mayavi** to visualize VASP data, it is recommened to install `Canopy environment <https://store.enthought.com/downloads/#default>`_ on your device instead of installing it manually.
49+
3D visualization is done via **PyVista**, which is installed automatically as a dependency::
5050

51-
After installing canopy, you can set corresponding aliases, for example:
52-
53-
.. code-block:: shell
54-
55-
alias canopy='/Users/<yourname>/Library/Enthought/Canopy/edm/envs/User/bin/python'
56-
alias canopy-pip='/Users/<yourname>/Library/Enthought/Canopy/edm/envs/User/bin/pip'
57-
alias canopy-ipython='/Users/<yourname>/Library/Enthought/Canopy/edm/envs/User/bin/ipython'
58-
alias canopy-jupyter='/Users/<yourname>/Library/Enthought/Canopy/edm/envs/User/bin/jupyter'
51+
pip install vaspy
5952

60-
Then you can install VASPy to canopy::
53+
Or install PyVista separately::
6154

62-
canopy-pip install vaspy
55+
pip install pyvista
6356

6457
Examples
6558
--------
@@ -98,7 +91,7 @@ Visualize ELFCAR
9891
>>> from vaspy.electro import ElfCar
9992
>>> a = ElfCar()
10093
>>> a.plot_contour() # Plot coutour
101-
>>> a.plot_mcontour() # Plot coutour using mlab(with Mayavi installed)
94+
>>> a.plot_mcontour() # Plot coutour using PyVista
10295
>>> a.plot_contour3d() # Plot 3D coutour
10396
>>> a.plot_field() # Plot scalar field
10497

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
matplotlib>=1.5.2
22
numpy>=1.11.1
33
scipy>=0.18.0
4-
4+
pyvista>=0.42.0

scripts/md_viz/map_mayavi.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66
from scipy.interpolate import interp2d
7-
from mayavi import mlab
7+
import pyvista as pv
88

99
from vaspy.iter import XdatCar, AniFile
1010
from traj import get_trajectories
@@ -43,11 +43,15 @@ def locate(x, y, position):
4343
newy = np.linspace(left, right, interp_resolution)
4444
newz = interp_func(newx, newy)
4545

46-
newy, newx = np.meshgrid(newx, newy)
47-
48-
face = mlab.surf(newx, newy, newz, warp_scale=40)
49-
mlab.axes(xlabel="x", ylabel="y", zlabel="z")
50-
mlab.outline(face)
51-
52-
mlab.show()
53-
46+
# PyVista surface plot
47+
nx, ny = len(newx), len(newy)
48+
X, Y = np.meshgrid(newx, newy, indexing='ij')
49+
Z = newz.T
50+
surface = pv.StructuredGrid(X[:, :, None], Y[:, :, None], Z[:, :, None])
51+
surface.point_data['scalars'] = Z[:, :, None].flatten(order='F')
52+
53+
pl = pv.Plotter()
54+
pl.add_mesh(surface, scalars='scalars', cmap='viridis',
55+
show_scalar_bar=True)
56+
pl.add_axes(xlabel='x', ylabel='y', zlabel='z')
57+
pl.show()

scripts/plot_test.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from plotter import *
22
import numpy as np
33
from scipy.interpolate import interp2d
4-
from mayavi import mlab
4+
import pyvista as pv
55
import mpl_toolkits.mplot3d
66
import matplotlib.pyplot as plt
77

@@ -17,20 +17,14 @@
1717
newy = np.linspace(0, np.max(y), 1000)
1818
newz = interpfunc(newx, newy)
1919

20-
#extent = [np.min(newx), np.max(newx), np.min(newy), np.max(newy)]
21-
#plt.contourf(newx.reshape(-1), newy.reshape(-1), newz, 20, extent=extent)
22-
#plt.colorbar()
23-
#
24-
##3d plot
25-
#fig3d = plt.figure()
26-
#ax3d = fig3d.add_subplot(111, projection='3d')
27-
#ax3d.plot_surface(newx, newy, newz, cmap=plt.cm.RdBu_r)
28-
#
29-
#plt.show()
20+
# PyVista surface plot
21+
nx, ny = len(newx), len(newy)
22+
X, Y = np.meshgrid(newx, newy, indexing='ij')
23+
Z = newz.T
24+
surface = pv.StructuredGrid(X[:, :, None], Y[:, :, None], Z[:, :, None])
25+
surface.point_data['scalars'] = Z[:, :, None].flatten(order='F')
3026

31-
#mlab
32-
face = mlab.surf(newx, newy, newz, warp_scale=2)
33-
mlab.axes(xlabel='x', ylabel='y', zlabel='z')
34-
mlab.outline(face)
35-
36-
mlab.show()
27+
pl = pv.Plotter()
28+
pl.add_mesh(surface, scalars='scalars', cmap='viridis', show_scalar_bar=True)
29+
pl.add_axes(xlabel='x', ylabel='y', zlabel='z')
30+
pl.show()

setup.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,27 +58,21 @@
5858
5959
python setup.py install
6060
61-
If you want to use **mayavi** to visualize VASP data, it is recommened to install `Canopy environment <https://store.enthought.com/downloads/#default>`_ on your device instead of installing it manually.
61+
3D visualization is done via **PyVista**, which is installed automatically as a dependency::
6262
63-
After installing canopy, you can set corresponding aliases, for example:
64-
65-
.. code-block:: shell
66-
67-
alias canopy='/Users/<yourname>/Library/Enthought/Canopy/edm/envs/User/bin/python'
68-
alias canopy-pip='/Users/zjshao/Library/Enthought/Canopy/edm/envs/User/bin/pip'
69-
alias canopy-ipython='/Users/<yourname>/Library/Enthought/Canopy/edm/envs/User/bin/ipython'
70-
alias canopy-jupyter='/Users/<yourname>/Library/Enthought/Canopy/edm/envs/User/bin/jupyter'
63+
pip install vaspy
7164
72-
Then you can install VASPy to canopy::
65+
Or install PyVista separately::
7366
74-
canopy-pip install vaspy
67+
pip install pyvista
7568
7669
"""
7770

7871
install_requires = [
7972
'numpy>=1.11.1',
8073
'matplotlib>=1.5.2',
8174
'scipy>=0.18.0',
75+
'pyvista>=0.42.0',
8276
]
8377

8478
license = 'LICENSE'

vaspy/electro.py

Lines changed: 80 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@
2626
print('Warning: Module matplotlib.pyplot is not installed')
2727
plt_installed = False
2828

29-
#whether mayavi installed
29+
#whether pyvista installed
3030
try:
31-
from mayavi import mlab
32-
mayavi_installed = True
31+
import pyvista as pv
32+
pyvista_installed = True
3333
except ImportError:
34-
mayavi_installed = False
34+
pyvista_installed = False
3535

3636
from vaspy.plotter import DataPlotter
3737
from vaspy.atomco import PosCar
@@ -262,8 +262,8 @@ def __init__(self, filename='ELFCAR'):
262262
------------- ame as PosCar ------------
263263
elf_data 3d array
264264
plot_contour method, use matplotlib to plot contours
265-
plot_mcontours method, use Mayavi.mlab to plot beautiful contour
266-
plot_contour3d method, use mayavi.mlab to plot 3d contour
265+
plot_mcontours method, use PyVista to plot beautiful contour
266+
plot_contour3d method, use PyVista to plot 3d contour
267267
plot_field method, plot scalar field for elf data
268268
============== =============================================
269269
"""
@@ -332,6 +332,28 @@ def expand_data(data, grid, widths):
332332

333333
return expanded_data, expanded_grid
334334

335+
def _build_structured_grid(self, data, grid):
336+
"""
337+
Build a pyvista.StructuredGrid from 3D data and POSCAR lattice vectors.
338+
339+
This properly handles non-orthogonal (e.g. hexagonal) cells by mapping
340+
fractional grid coordinates to Cartesian space using the lattice bases.
341+
"""
342+
nx, ny, nz = grid
343+
x_frac = np.linspace(0, 1, nx, endpoint=False)
344+
y_frac = np.linspace(0, 1, ny, endpoint=False)
345+
z_frac = np.linspace(0, 1, nz, endpoint=False)
346+
X, Y, Z = np.meshgrid(x_frac, y_frac, z_frac, indexing='ij')
347+
348+
bases = self.bases * self.bases_const
349+
cart_X = X * bases[0, 0] + Y * bases[1, 0] + Z * bases[2, 0]
350+
cart_Y = X * bases[0, 1] + Y * bases[1, 1] + Z * bases[2, 1]
351+
cart_Z = X * bases[0, 2] + Y * bases[1, 2] + Z * bases[2, 2]
352+
353+
pvgrid = pv.StructuredGrid(cart_X, cart_Y, cart_Z)
354+
pvgrid.point_data['values'] = data.flatten(order='F')
355+
return pvgrid
356+
335357
# 装饰器
336358
def contour_decorator(func):
337359
'''
@@ -437,29 +459,34 @@ def plot_contour(self, ndim0, ndim1, z, show_mode):
437459

438460
@contour_decorator
439461
def plot_mcontour(self, ndim0, ndim1, z, show_mode):
440-
"use mayavi.mlab to plot contour."
441-
if not mayavi_installed:
442-
self.__logger.info("Mayavi is not installed on your device.")
462+
"use PyVista to plot surface contour."
463+
if not pyvista_installed:
464+
self.__logger.info("PyVista is not installed on your device.")
443465
return
444466
#do 2d interpolation
445-
#get slice object
446467
s = np.s_[0:ndim0:1, 0:ndim1:1]
447468
x, y = np.ogrid[s]
448-
mx, my = np.mgrid[s]
449469
#use cubic 2d interpolation
450470
interpfunc = interp2d(x, y, z, kind='cubic')
451471
newx = np.linspace(0, ndim0, 600)
452472
newy = np.linspace(0, ndim1, 600)
453-
newz = interpfunc(newx, newy)
454-
#mlab
455-
face = mlab.surf(newx, newy, newz, warp_scale=2)
456-
mlab.axes(xlabel='x', ylabel='y', zlabel='z')
457-
mlab.outline(face)
473+
newz = interpfunc(newx, newy) # shape: (len(newy), len(newx))
474+
# Build structured surface (3D grid with thickness 1)
475+
nx, ny = len(newx), len(newy)
476+
X, Y = np.meshgrid(newx, newy, indexing='ij') # (nx, ny)
477+
Z = newz.T # transpose to (nx, ny)
478+
surface = pv.StructuredGrid(X[:, :, None], Y[:, :, None], Z[:, :, None])
479+
surface.point_data['scalars'] = Z[:, :, None].flatten(order='F')
480+
# Plot
481+
pl = pv.Plotter()
482+
pl.add_mesh(surface, scalars='scalars', cmap='viridis',
483+
show_scalar_bar=True)
484+
pl.add_axes(xlabel='x', ylabel='y', zlabel='z')
458485
#save or show
459486
if show_mode == 'show':
460-
mlab.show()
487+
pl.show()
461488
elif show_mode == 'save':
462-
mlab.savefig('mlab_contour3d.png')
489+
pl.screenshot('pyvista_contour3d.png')
463490
else:
464491
raise ValueError('Unrecognized show mode parameter : ' +
465492
show_mode)
@@ -468,7 +495,7 @@ def plot_mcontour(self, ndim0, ndim1, z, show_mode):
468495

469496
def plot_contour3d(self, **kwargs):
470497
'''
471-
use mayavi.mlab to plot 3d contour.
498+
use PyVista to plot 3d isosurface contour.
472499
473500
Parameter
474501
---------
@@ -480,61 +507,61 @@ def plot_contour3d(self, **kwargs):
480507
number of replication on x, y, z axis,
481508
}
482509
'''
483-
if not mayavi_installed:
484-
self.__logger.warning("Mayavi is not installed on your device.")
510+
if not pyvista_installed:
511+
self.__logger.warning("PyVista is not installed on your device.")
485512
return
486513
# set parameters
487514
widths = kwargs['widths'] if 'widths' in kwargs else (1, 1, 1)
488515
elf_data, grid = self.expand_data(self.elf_data, self.grid, widths)
489-
# import pdb; pdb.set_trace()
490516
maxdata = np.max(elf_data)
491517
maxct = kwargs['maxct'] if 'maxct' in kwargs else maxdata
492-
# check maxct
493518
if maxct > maxdata:
494519
self.__logger.warning("maxct is larger than %f", maxdata)
495520
opacity = kwargs['opacity'] if 'opacity' in kwargs else 0.6
496521
nct = kwargs['nct'] if 'nct' in kwargs else 5
497-
# plot surface
498-
surface = mlab.contour3d(elf_data)
499-
# set surface attrs
500-
surface.actor.property.opacity = opacity
501-
surface.contour.maximum_contour = maxct
502-
surface.contour.number_of_contours = nct
503-
# reverse axes labels
504-
mlab.axes(xlabel='z', ylabel='y', zlabel='x') # 是mlab参数顺序问题?
505-
mlab.outline()
506-
mlab.show()
522+
# Build StructuredGrid with proper cell geometry
523+
pvgrid = self._build_structured_grid(elf_data, grid)
524+
# Extract isosurfaces
525+
contours = pvgrid.contour(nct, scalars='values',
526+
rng=(0, maxct) if maxct < maxdata else None)
527+
# Plot
528+
pl = pv.Plotter()
529+
pl.add_mesh(contours, opacity=opacity, cmap='viridis',
530+
show_scalar_bar=True)
531+
pl.add_axes(xlabel='a', ylabel='b', zlabel='c')
532+
pl.show()
507533

508534
return
509535

510536
def plot_field(self, **kwargs):
511-
"plot scalar field for elf data"
512-
if not mayavi_installed:
513-
self.__logger.warning("Mayavi is not installed on your device.")
537+
"Plot scalar field volume with interactive cut plane."
538+
if not pyvista_installed:
539+
self.__logger.warning("PyVista is not installed on your device.")
514540
return
515541
# set parameters
516542
vmin = kwargs['vmin'] if 'vmin' in kwargs else 0.0
517543
vmax = kwargs['vmax'] if 'vmax' in kwargs else 1.0
518-
axis_cut = kwargs['axis_cut'] if 'axis_cut' in kwargs else 'z'
544+
axis_cut = kwargs.get('axis_cut', 'z')
519545
nct = kwargs['nct'] if 'nct' in kwargs else 5
520546
widths = kwargs['widths'] if 'widths' in kwargs else (1, 1, 1)
521547
elf_data, grid = self.expand_data(self.elf_data, self.grid, widths)
522-
#create pipeline
523-
field = mlab.pipeline.scalar_field(elf_data) # data source
524-
mlab.pipeline.volume(field, vmin=vmin, vmax=vmax) # put data into volumn to visualize
525-
#cut plane
526-
if axis_cut in ['Z', 'z']:
527-
plane_orientation = 'z_axes'
528-
elif axis_cut in ['Y', 'y']:
529-
plane_orientation = 'y_axes'
530-
elif axis_cut in ['X', 'x']:
531-
plane_orientation = 'x_axes'
532-
cut = mlab.pipeline.scalar_cut_plane(
533-
field.children[0], plane_orientation=plane_orientation)
534-
cut.enable_contours = True # 开启等值线显示
535-
cut.contour.number_of_contours = nct
536-
mlab.show()
537-
#mlab.savefig('field.png', size=(2000, 2000))
548+
# Build StructuredGrid with proper cell geometry
549+
pvgrid = self._build_structured_grid(elf_data, grid)
550+
# Determine cut plane normal
551+
normals = {'x': (1, 0, 0), 'y': (0, 1, 0), 'z': (0, 0, 1)}
552+
normal = normals.get(axis_cut.lower(), (0, 0, 1))
553+
# Slice through center
554+
center = pvgrid.center
555+
single_slice = pvgrid.slice(normal=normal, origin=center)
556+
# Contours on the slice
557+
edges = single_slice.contour(nct, scalars='values')
558+
# Plot
559+
pl = pv.Plotter()
560+
pl.add_volume(pvgrid, scalars='values', clim=(vmin, vmax),
561+
cmap='viridis', opacity='linear')
562+
pl.add_mesh(edges, color='black', line_width=1)
563+
pl.add_axes(xlabel='a', ylabel='b', zlabel='c')
564+
pl.show()
538565

539566
return
540567

0 commit comments

Comments
 (0)