Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump min supported Python to 3.10 #195

Merged
merged 14 commits into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,43 @@ jobs:
uses: janosh/workflows/.github/workflows/pytest-release.yml@main
with:
os: ${{ matrix.os }}
python-version: 3.9
python-version: "3.10"
secrets: inherit

find-scripts:
runs-on: ubuntu-latest
outputs:
script_list: ${{ steps.set-matrix.outputs.script_list }}
steps:
- name: Check out repository
uses: actions/checkout@v4

- name: Find Python scripts
id: set-matrix
run: |
SCRIPTS=$(find examples/make_assets -name "*.py" | jq -R -s -c 'split("\n")[:-1]')
echo "script_list=$SCRIPTS" >> $GITHUB_OUTPUT

test-scripts:
needs: find-scripts
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
script: ${{fromJson(needs.find-scripts.outputs.script_list)}}
steps:
- name: Check out repository
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.10"

- name: Install package and dependencies
run: pip install -e .[make-assets]

- name: Run script
run: python ${{ matrix.script }}
env:
MP_API_KEY: ${{ secrets.MP_API_KEY }}
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.1
rev: v0.6.3
hooks:
- id: ruff
args: [--fix]
Expand All @@ -17,7 +17,7 @@ repos:
types_or: [python, jupyter]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.1
rev: v1.11.2
hooks:
- id: mypy
additional_dependencies: [types-requests]
Expand Down Expand Up @@ -73,7 +73,7 @@ repos:
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.9.0
rev: v9.9.1
hooks:
- id: eslint
types: [file]
Expand All @@ -87,6 +87,6 @@ repos:
- typescript-eslint

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.376
rev: v1.1.378
hooks:
- id: pyright
16 changes: 12 additions & 4 deletions examples/dataset_exploration/matpes/eda.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,21 @@
df_pbe[Key.forces] = df_pbe[Key.forces].map(np.abs)

df_r2scan_elem_forces = pd.DataFrame(
{site.specie.symbol: np.linalg.norm(force) for site, force in zip(struct, forces)}
for struct, forces in zip(df_r2scan[Key.structure], df_r2scan[Key.forces])
{
site.specie.symbol: np.linalg.norm(force)
for site, force in zip(struct, forces, strict=True)
}
for struct, forces in zip(
df_r2scan[Key.structure], df_r2scan[Key.forces], strict=True
)
).mean()

df_pbe_elem_forces = pd.DataFrame(
{site.specie.symbol: np.linalg.norm(force) for site, force in zip(struct, forces)}
for struct, forces in zip(df_pbe[Key.structure], df_pbe[Key.forces])
{
site.specie.symbol: np.linalg.norm(force)
for site, force in zip(struct, forces, strict=True)
}
for struct, forces in zip(df_pbe[Key.structure], df_pbe[Key.forces], strict=True)
).mean()


Expand Down
6 changes: 6 additions & 0 deletions examples/make_assets/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
from pymatviz.enums import Key


try:
import ffonons # noqa: F401
except ImportError:
raise SystemExit(0) from None # install ffonons to run this script


# %% Plot phonon bands and DOS
for mp_id, formula in (
("mp-2758", "Sr4Se4"),
Expand Down
4 changes: 2 additions & 2 deletions examples/make_assets/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
np_rng = np.random.default_rng(seed=0)
y_true = np_rng.normal(5, 4, rand_regression_size)
y_pred = 1.2 * y_true - 2 * np_rng.normal(0, 1, rand_regression_size)
y_std = (y_true - y_pred) * 10 * np_rng.normal(0, 0.1, rand_regression_size)
y_std = abs((y_true - y_pred) * 10 * np_rng.normal(0, 0.1, rand_regression_size))


# %% density scatter plotly
Expand All @@ -42,7 +42,7 @@
xs, ys = make_blobs(n_samples=100_000, centers=3, n_features=2, random_state=42)

x_col, y_col, target_col = "feature1", "feature2", "target"
df_blobs = pd.DataFrame(dict(zip([x_col, y_col], xs.T)) | {target_col: ys})
df_blobs = pd.DataFrame(dict(zip([x_col, y_col], xs.T, strict=True)) | {target_col: ys})

fig = pmv.density_scatter_plotly(df=df_blobs, x=x_col, y=y_col)
fig.show()
Expand Down
4 changes: 2 additions & 2 deletions examples/make_assets/uncertainty.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@


# %% Cumulative Plots
ax = pmv.cumulative_error(y_pred, y_true)
ax = pmv.cumulative_error(y_pred - y_true)
pmv.io.save_and_compress_svg(ax, "cumulative-error")


ax = pmv.cumulative_residual(y_pred, y_true)
ax = pmv.cumulative_residual(y_pred - y_true)
pmv.io.save_and_compress_svg(ax, "cumulative-residual")
4 changes: 3 additions & 1 deletion pymatviz/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def spacegroup_bar(

# sort df by crystal system going from smallest to largest spacegroup numbers
# e.g. triclinic (1-2) comes first, cubic (195-230) last
sys_order = dict(zip(crystal_sys_colors, range(len(crystal_sys_colors))))
sys_order = dict(
zip(crystal_sys_colors, range(len(crystal_sys_colors)), strict=True)
)
df_data = df_data.loc[
df_data[Key.crystal_system].map(sys_order).sort_values().index
]
Expand Down
10 changes: 5 additions & 5 deletions pymatviz/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pathlib import Path
from shutil import which
from time import sleep
from typing import TYPE_CHECKING, Any, Callable, Final, Literal
from typing import TYPE_CHECKING, Any, Final, Literal

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -24,7 +24,7 @@


if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from pathlib import Path

import pandas as pd
Expand Down Expand Up @@ -130,7 +130,7 @@ def save_fig(
if any(var in os.environ for var in env_disable):
return
# handle matplotlib figures
if isinstance(fig, (plt.Figure, plt.Axes)):
if isinstance(fig, plt.Figure | plt.Axes):
if hasattr(fig, "figure"):
fig = fig.figure # unwrap Axes
fig.savefig(path, **kwargs, transparent=True)
Expand Down Expand Up @@ -566,7 +566,7 @@ def print_table(
x0 = (1 - total_width) / 2
y_i = 1

for idx, (yd, row) in enumerate(zip(row_locs, rows)):
for idx, (yd, row) in enumerate(zip(row_locs, rows, strict=True)):
x_i = x0
y_i -= yd
# table zebra stripes
Expand All @@ -581,7 +581,7 @@ def print_table(
)
fig.add_artist(rect)

for xd, val in zip(col_widths, row):
for xd, val in zip(col_widths, row, strict=True):
text, weight, ha, bg_color, fg_color = val[:5]

if bg_color != row_colors[1]:
Expand Down
10 changes: 5 additions & 5 deletions pymatviz/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import sys
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Union, get_args, no_type_check
from typing import TYPE_CHECKING, Any, Literal, get_args, no_type_check

import plotly.express as px
import plotly.graph_objects as go
Expand All @@ -23,7 +23,7 @@
from pymatgen.core import Structure
from typing_extensions import Self

AnyBandStructure = Union[BandStructureSymmLine, PhononBands]
AnyBandStructure = BandStructureSymmLine | PhononBands


@dataclass
Expand Down Expand Up @@ -121,8 +121,8 @@ def get_band_xaxis_ticks(
return ticks_x_pos, tick_labels


YMin = Union[float, Literal["y_min"]]
YMax = Union[float, Literal["y_max"]]
YMin = float | Literal["y_min"]
YMax = float | Literal["y_max"]


@no_type_check
Expand All @@ -133,7 +133,7 @@ def _shaded_range(
return fig

shade_defaults = dict(layer="below", row="all", col="all")
y_lim = dict(zip(("y_min", "y_max"), fig.layout.yaxis.range))
y_lim = dict(zip(("y_min", "y_max"), fig.layout.yaxis.range, strict=True))

shaded_ys = shaded_ys or {(0, "y_min"): dict(fillcolor="gray", opacity=0.07)}
for (y0, y1), kwds in shaded_ys.items():
Expand Down
4 changes: 2 additions & 2 deletions pymatviz/powerups/both.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def annotate_metrics(
"""
if isinstance(metrics, str):
metrics = [metrics]
if not isinstance(metrics, (dict, list, tuple, set)):
if not isinstance(metrics, dict | list | tuple | set):
raise TypeError(
f"metrics must be dict|list|tuple|set, not {type(metrics).__name__}"
)
Expand Down Expand Up @@ -166,7 +166,7 @@ def add_identity_line(
"""
(x_min, x_max), (y_min, y_max) = get_fig_xy_range(fig=fig, trace_idx=trace_idx)

if isinstance(fig, (plt.Figure, plt.Axes)): # handle matplotlib
if isinstance(fig, plt.Figure | plt.Axes): # handle matplotlib
ax = fig if isinstance(fig, plt.Axes) else fig.gca()

line_defaults = dict(alpha=0.5, zorder=0, linestyle="dashed", color="black")
Expand Down
4 changes: 2 additions & 2 deletions pymatviz/powerups/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def annotate_bars(

y_max: float = 0
texts: list[Annotation] = []
for rect, label in zip(ax.patches, labels):
for rect, label in zip(ax.patches, labels, strict=True):
y_pos = rect.get_height()
x_pos = rect.get_x() + rect.get_width() / 2 + h_offset

Expand All @@ -113,7 +113,7 @@ def annotate_bars(

y_max = max(y_max, y_pos)

txt = f"{label:,}" if isinstance(label, (int, float)) else label
txt = f"{label:,}" if isinstance(label, int | float) else label
# place label at end of the bar and center horizontally
anno = ax.annotate(
txt, (x_pos, y_pos), ha="center", fontsize=fontsize, **kwargs
Expand Down
2 changes: 1 addition & 1 deletion pymatviz/powerups/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def add_ecdf_line(

target_trace: BaseTraceType = fig.data[trace_idx]
if values is None or len(values) == 0:
if isinstance(target_trace, (go.Histogram, go.Scatter, go.Scattergl)):
if isinstance(target_trace, go.Histogram | go.Scatter | go.Scattergl):
values = target_trace.x
elif isinstance(target_trace, go.Bar):
xs, ys = target_trace.x, target_trace.y
Expand Down
16 changes: 9 additions & 7 deletions pymatviz/ptable/_process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import warnings
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Literal, Union, get_args
from typing import TYPE_CHECKING, Literal, TypeAlias, get_args

import numpy as np
import pandas as pd
Expand All @@ -14,19 +14,21 @@


if TYPE_CHECKING:
from typing import Any, Callable
from collections.abc import Callable
from typing import Any

from numpy.typing import NDArray


# Data types that can be passed to PTableProjector and normalized by data_preprocessor
# to SupportedValueType
SupportedDataType = Union[
dict[str, Union[float, Sequence[float], np.ndarray]], pd.DataFrame, pd.Series
]
SupportedDataType: TypeAlias = (
dict[str, float | Sequence[float] | np.ndarray] | pd.DataFrame | pd.Series
)


# Data types used internally by ptable plotters (returned by preprocess_ptable_data)
SupportedValueType = Union[Sequence[float], np.ndarray]
SupportedValueType: TypeAlias = Sequence[float] | np.ndarray


class PTableData:
Expand Down Expand Up @@ -173,7 +175,7 @@ def _write_meta_data(self) -> None:
mean: The mean value.
vmax: The max value.
"""
numeric_values = pd.to_numeric(
numeric_values: pd.Series = pd.to_numeric(
self._data[self.val_col].explode().explode().explode(), errors="coerce"
)
self._data.attrs["vmin"] = numeric_values.min()
Expand Down
9 changes: 5 additions & 4 deletions pymatviz/ptable/_projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@


if TYPE_CHECKING:
from typing import Any, Callable
from collections.abc import Callable
from typing import Any

import pandas as pd
from matplotlib.typing import ColorType
Expand Down Expand Up @@ -533,7 +534,7 @@ def rectangle(
tick_kwargs: For compatibility with other plotters.
"""
# Map values to colors
if isinstance(data, (Sequence, np.ndarray)):
if isinstance(data, Sequence | np.ndarray):
colors = [cmap(norm(value)) for value in data]
else:
raise TypeError("Unsupported data type.")
Expand Down Expand Up @@ -657,7 +658,7 @@ def histogram(
cols = (n - n.min()) / (n.max() - n.min())

# Apply colors
for col, patch in zip(cols, patches):
for col, patch in zip(cols, patches, strict=True):
plt.setp(patch, "facecolor", cmap(col))

# Set tick labels
Expand All @@ -683,7 +684,7 @@ def filter_near_zero(self, tol: float = 1e-6) -> None:

def to_scalar(x: float | list[float] | NDArray) -> float:
"""Convert single value array/list to scalar."""
if isinstance(x, (list, np.ndarray)):
if isinstance(x, list | np.ndarray):
return x[0] if len(x) > 0 else np.nan
return x

Expand Down
Loading
Loading