From 2ae84be957bbe14838b1b229db2446a0812c9a39 Mon Sep 17 00:00:00 2001 From: Brock Mendel Date: Tue, 7 Jan 2020 16:11:52 -0800 Subject: [PATCH 1/2] REF: PeriodIndex._union --- pandas/core/indexes/base.py | 14 +++++--------- pandas/core/indexes/category.py | 2 ++ pandas/core/indexes/period.py | 26 ++++++++++++++++++++------ 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 087db014de5b3..bfa6860457a59 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2310,14 +2310,11 @@ def _union(self, other, sort): return other._get_reconciled_name_object(self) # TODO(EA): setops-refactor, clean all this up - if is_period_dtype(self) or is_datetime64tz_dtype(self): - lvals = self._ndarray_values - else: - lvals = self._values - if is_period_dtype(other) or is_datetime64tz_dtype(other): - rvals = other._ndarray_values - else: - rvals = other._values + # PeriodIndex and DatetimeIndex both override _union + assert not (is_period_dtype(self.dtype) or is_datetime64tz_dtype(self.dtype)) + assert not (is_period_dtype(other.dtype) or is_datetime64tz_dtype(other.dtype)) + lvals = self._values + rvals = other._values if sort is None and self.is_monotonic and other.is_monotonic: try: @@ -2412,7 +2409,6 @@ def intersection(self, other, sort=False): other = other.astype("O") return this.intersection(other, sort=sort) - # TODO(EA): setops-refactor, clean all this up lvals = self._values rvals = other._values diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index f61721a0e51e6..0c709e562517b 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -388,6 +388,8 @@ def values(self): def _wrap_setop_result(self, other, result): name = get_op_result_name(self, other) + # We use _shallow_copy rather than the Index implementation + # (which uses _constructor) in order to preserve dtype. return self._shallow_copy(result, name=name) @Appender(_index_shared_docs["contains"] % _index_doc_kwargs) diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index 1fed0201f7b2b..89ae665c36376 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -768,12 +768,6 @@ def _assert_can_do_setop(self, other): if isinstance(other, PeriodIndex) and self.freq != other.freq: raise raise_on_incompatible(self, other) - def _wrap_setop_result(self, other, result): - name = get_op_result_name(self, other) - result = self._apply_meta(result) - result.name = name - return result - def intersection(self, other, sort=False): self._validate_sort_keyword(sort) self._assert_can_do_setop(other) @@ -819,6 +813,26 @@ def difference(self, other, sort=None): result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name) return result + def _union(self, other, sort): + if not len(other) or self.equals(other) or not len(self): + return super()._union(other, sort=sort) + + # We are called by `union`, which is responsible for this validation + assert isinstance(other, type(self)) + + if not is_dtype_equal(self.dtype, other.dtype): + this = self.astype("O") + other = other.astype("O") + return this._union(other, sort=sort) + + i8self = Int64Index._simple_new(self.asi8) + i8other = Int64Index._simple_new(other.asi8) + i8result = i8self._union(i8other, sort=sort) + + res_name = get_op_result_name(self, other) + result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name) + return result + # ------------------------------------------------------------------------ def _apply_meta(self, rawarr): From 016d37a8b369690e4179cede583dd7a7af9ed5b1 Mon Sep 17 00:00:00 2001 From: Brock Mendel Date: Tue, 7 Jan 2020 16:50:13 -0800 Subject: [PATCH 2/2] update check --- pandas/core/indexes/base.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index bfa6860457a59..0ede340720922 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2310,11 +2310,14 @@ def _union(self, other, sort): return other._get_reconciled_name_object(self) # TODO(EA): setops-refactor, clean all this up - # PeriodIndex and DatetimeIndex both override _union - assert not (is_period_dtype(self.dtype) or is_datetime64tz_dtype(self.dtype)) - assert not (is_period_dtype(other.dtype) or is_datetime64tz_dtype(other.dtype)) - lvals = self._values - rvals = other._values + if is_datetime64tz_dtype(self): + lvals = self._ndarray_values + else: + lvals = self._values + if is_datetime64tz_dtype(other): + rvals = other._ndarray_values + else: + rvals = other._values if sort is None and self.is_monotonic and other.is_monotonic: try: @@ -2409,6 +2412,7 @@ def intersection(self, other, sort=False): other = other.astype("O") return this.intersection(other, sort=sort) + # TODO(EA): setops-refactor, clean all this up lvals = self._values rvals = other._values