matplotlib 设置时间轴格式,仅显示特定日期之后的刻度标签

u5i3ibmn  于 2023-10-24  发布在  其他
关注(0)|答案(1)|浏览(114)

我正在尝试格式化子图中的时间轴,以便xtick标签仅出现在数据开始之后。

import pandas as pd
import numpy as np
from datetime import date

import matplotlib.pyplot as plt
from matplotlib.dates import MonthLocator, num2date
from matplotlib.ticker import FuncFormatter

# Three monthly series (pandas 'M' = month-end frequency), starting in July of
# 2016, 2017 and 2018 respectively and all ending at 2021-12-31.
xs = [pd.date_range(f'{y}-07-01', '2021-12-31', freq='M')
      for y in range(2016, 2019)]
# One array of random y-values per series, matching its length.
ys = [np.random.rand(len(x)) for x in xs]

# One subplot per series, stacked vertically.
fig, axs = plt.subplots(3, 1, figsize=(10, 8))
for ax, x, y in zip(axs, xs, ys):
    ax.plot(x, y)

    # Custom formatting function
    # NOTE(review): date_min is rebound on every iteration, and the function
    # below reads it when it is *called*, not when it is defined. If the
    # formatter is only invoked after the loop has finished (presumably at
    # draw time), all three axes see the last date_min -- which matches the
    # identical tick labels described in the question.
    date_min = x[0].date()
    # The parameter x below shadows the loop variable x inside the function.
    def custom_date_formatter(x, pos):
        # x is a Matplotlib date number; pos (tick position) is unused.
        dt = num2date(x)
        if dt.date() < date_min:
            # Hide labels for ticks before the start of the data.
            return ''
        elif dt.month == 1:
            # January ticks show the year.
            return dt.strftime('%Y')
        else:
            # Other ticks show the upper-cased month abbreviation.
            return dt.strftime('%b').upper()

    # Major ticks at the start of each quarter (Jan, Apr, Jul, Oct).
    ax.xaxis.set_major_locator(MonthLocator((1, 4, 7, 10)))
    ax.xaxis.set_major_formatter(FuncFormatter(custom_date_formatter))
    ax.tick_params(axis='x', labelsize=8)

    # Same left x-limit on every subplot; the right limit is left unset.
    ax.set_xlim(date(2016, 7, 1))

但我得到了一个所有轴都显示相同xtick标签的图:

但我希望它们是:

JUL - OCT - 2017 - APR - JUL - OCT - 2018 .....

..... JUL - OCT - 2018 - APR - JUL - OCT - 2019 .....

............ JUL - OCT - 2019 - APR - JUL - OCT - 2020 .....

我应该如何修复custom_date_formatter函数来实现这一点?

gudnpqoy

gudnpqoy1#

您遇到的问题是由于在循环中定义了函数,并且与所谓的closure有关

例如,考虑下面这个经过简化、但本质相同的例子:

# Collect three callables created inside a loop. Each one reads the
# module-level ``i`` when it is called, not when it is created.
functions = []
for i in range(3):
    functions.append(lambda: print(i))

# The loop has finished, so ``i`` is 2 and every call prints 2.
for f in functions:
    f()

它给出了以下可能意外的输出(就像你的问题一样):

2
2
2

循环的每次迭代都会定义一个新的函数,但这些函数并不会保存定义时 i 的值,而是在被调用时才去查找变量 i(即所谓的"延迟绑定")。当列表中的三个函数被调用时,循环已经结束,i 的值为 2,因此三次输出相同,而不是 0, 1, 2

修改代码,将custom_date_formatter函数移出subplots axes循环

from datetime import date
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from matplotlib.dates import MonthLocator
from matplotlib.dates import num2date
from matplotlib.ticker import FuncFormatter

def generate_data(start_year, end_year):
    """Generate random monthly data series, one per starting year.

    Args:
        start_year (int): First starting year (inclusive).
        end_year (int): Last starting year (exclusive).

    Returns:
        tuple: A tuple containing two lists with one entry for each year
               ``y`` in ``range(start_year, end_year)``:
               - A list of month-end pandas DatetimeIndex objects running
                 from June of ``y`` to December 2021.
               - A list of numpy arrays of random floats in [0, 1), one
                 value for each date in the respective index.
               Both lists are empty when ``end_year <= start_year``.
    """
    # An explicit MonthEnd offset gives the same dates as the "M" alias but
    # avoids the alias itself, which is deprecated in recent pandas releases.
    month_end = pd.offsets.MonthEnd()
    date_ranges = [
        pd.date_range(f"{y}-06-01", "2021-12-31", freq=month_end)
        for y in range(start_year, end_year)
    ]
    data_values = [np.random.rand(len(dr)) for dr in date_ranges]

    return date_ranges, data_values

def custom_date_formatter(x, date_min):
    """Build the tick label for a Matplotlib date number.

    Args:
        x (float): Tick location as a Matplotlib date number; it is
                   converted with ``num2date``.
        date_min (datetime.date): Earliest date that should receive a
                   label.

    Returns:
        str: ``""`` when the tick falls before ``date_min``; the four-digit
             year for ticks in January; otherwise the upper-cased
             abbreviated month name.
    """
    moment = num2date(x)

    # Ticks before the start of the data get no label.
    if moment.date() < date_min:
        return ""

    # January marks a new year, so show the year instead of the month.
    if moment.month == 1:
        return moment.strftime("%Y")

    month_abbr = moment.strftime("%b")
    return month_abbr.upper()

def _make_tick_formatter(date_min):
    """Return an ``(x, pos)`` tick formatter bound to *date_min*.

    ``date_min`` is a parameter of this factory, so every formatter it
    returns keeps its own value even when formatters are created in a loop
    and only called later.
    """
    def formatter(x, pos):
        # pos (the tick index) is part of the FuncFormatter call signature
        # but is not needed to build the label.
        return custom_date_formatter(x, date_min)

    return formatter


def plot_data(date_ranges, data_values):
    """Plot each series in its own subplot with custom x-axis formatting.

    Every subplot gets quarterly major ticks (Jan, Apr, Jul, Oct) whose
    labels come from ``custom_date_formatter``; ticks that would have a
    blank label (before the first date of that subplot's series) are
    removed.

    Args:
        date_ranges (list): A list of pandas date_range objects to be plotted
                            on the x-axis, one per subplot.
        data_values (list): A list of numpy arrays containing the y-values for
                            each date_range.
    """
    # squeeze=False always yields a 2-D array of axes, so a single series
    # works too (otherwise plt.subplots returns a bare Axes object).
    fig, axs = plt.subplots(
        len(date_ranges), 1, figsize=(10, 8), squeeze=False
    )

    for ax, dates, values in zip(axs[:, 0], date_ranges, data_values):
        ax.plot(dates, values)

        # Bind this subplot's start date now. A lambda or closure that reads
        # the loop variable directly would see the last iteration's value by
        # the time the figure is drawn.
        bound_formatter = _make_tick_formatter(dates[0].date())
        ax.xaxis.set_major_locator(MonthLocator((1, 4, 7, 10)))
        ax.xaxis.set_major_formatter(FuncFormatter(bound_formatter))

        # Remove x ticks whose label would be blank.
        ticks = ax.get_xticks()
        update_ticks = [
            tick for tick in ticks if bound_formatter(tick, None) != ""
        ]
        ax.set_xticks(update_ticks)

        ax.tick_params(axis="x", labelsize=8)
        # Same left x-limit on every subplot.
        ax.set_xlim(date(2016, 7, 1))

    plt.tight_layout()
    plt.show()

# Build one sample series per starting year 2016-2018, then draw them.
date_ranges, data_values = generate_data(start_year=2016, end_year=2019)
plot_data(date_ranges=date_ranges, data_values=data_values)

相关问题