# GitHub Gist by @iydon (id 0a079d572127a01fef74c0a48e180ccb),
# created February 15, 2023 09:09.
import collections as c
import copy
import random
import typing as t

if t.TYPE_CHECKING:
    from typing_extensions import Self
Category = t.Hashable
Count = int
Index = int


class Group:
    '''Split a sequence of categories into segments at change points.

    A change point is the first index of a run of more than ``threshold``
    consecutive items whose category differs from the most frequent
    ("top") category of the current segment; see :meth:`indices`.

    Example:
    ```
    cats = [
        1, 1, 1, 2, 3, 1, 1, 2, 1,
        3, 2, 3, 3, 4, 1, 3, 3, 3,
        1, 2, 1, 2, 2, 2, 2, 2, 2,
        1, 4, 3, 2, 4, 4, 4, 4, 4,
    ]
    group = Group(cats)
    ```
    '''

    def __init__(self, cats: t.Iterable[Category]) -> None:
        # Stored as given, not copied. `indices` iterates it once per call,
        # so a one-shot iterator only supports a single pass.
        self._cats = cats

    @classmethod
    def random(
        cls,
        x: int = 99,
        y: int = 9,
        z: float = 0.8,
        *,
        rng: 't.Optional[random.Random]' = None,
    ) -> t.List[Category]:
        '''Generate a synthetic, shuffled sequence of category blocks.

        Builds ``y`` blocks of ``x`` items each. Block ``i`` represents
        category ``i``: each of its items is ``i`` with probability ``z``
        and otherwise a category drawn uniformly from ``range(y)``. The
        blocks are then shuffled and concatenated.

        Args:
            x: number of items per block.
            y: number of categories (and of blocks).
            z: probability that an item keeps its block's category.
            rng: optional ``random.Random`` instance for reproducible
                output. When omitted, the module-level ``random``
                functions (the shared global generator) are used.

        Returns:
            A flat list of ``x * y`` integer categories.
        '''
        gen = random if rng is None else rng
        blocks = [
            [
                ith if gen.random() < z else gen.randrange(y)
                for _ in range(x)
            ]
            for ith in range(y)
        ]
        gen.shuffle(blocks)
        # Flatten in linear time; sum(blocks, start=[]) would re-copy the
        # growing list once per block.
        return [cat for block in blocks for cat in block]

    @property
    def cats(self) -> t.Iterable[Category]:
        '''The category sequence this group was built from (not a copy).'''
        return self._cats

    def copy(self) -> 'Self':
        '''Return a new instance of the same class over a deep copy of ``cats``.'''
        return self.__class__(copy.deepcopy(self._cats))

    def indices(self, threshold: Count = 2) -> t.Iterator[Index]:
        '''Yield the start index of every detected change point.

        Two counters are kept while scanning:

        * ``old``: category counts of the current segment. Its most frequent
          key is the segment's top category.
        * ``new``: category counts of the candidate segment. Once a run of
          non-top items starts, it is restarted at ``prev_idx``, the first
          item of that run.

        ``prev_cnt`` is the length of the current run of consecutive non-top
        items. It resets to 0 as soon as the top category reappears. When it
        exceeds ``threshold``, ``prev_idx`` is yielded and the candidate
        counts become the counts of the new segment.

        Example (using ``group`` from the class docstring):
            >>> for idx in group.indices(threshold=2):
            ...     print(idx)
            9
            18
            27
        '''
        old, new = c.defaultdict(Count), c.defaultdict(Count)
        prev_idx, prev_cnt = 0, 0
        for ith, cat in enumerate(self._cats):
            old[cat] += 1
            new[cat] += 1
            if cat != self._top(old):
                if prev_cnt == 0:
                    # First non-top item of a new run: restart the
                    # candidate counts from this item.
                    new.clear()
                    new[cat] = 1
                    prev_idx = ith
                prev_cnt += 1
            else:
                # The top category reappeared, so the run is broken.
                prev_cnt = 0
            if prev_cnt > threshold:
                # Promote the candidate segment to the current segment.
                old, prev_cnt = new.copy(), 0
                yield prev_idx

    def slices(self, threshold: Count = 2) -> t.Iterable[slice]:
        '''Yield consecutive slices that partition ``cats`` at change points.

        The first slice starts at 0 and the last one is open-ended (its
        ``stop`` is ``None``). A sequence without change points, including
        an empty one, yields the single slice ``slice(0, None)``.

        Example (using ``cats`` and ``group`` from the class docstring):
            >>> for part in group.slices(threshold=2):
            ...     print(cats[part])
            [1, 1, 1, 2, 3, 1, 1, 2, 1]
            [3, 2, 3, 3, 4, 1, 3, 3, 3]
            [1, 2, 1, 2, 2, 2, 2, 2, 2]
            [1, 4, 3, 2, 4, 4, 4, 4, 4]
        '''
        start = 0
        for stop in self.indices(threshold=threshold):
            yield slice(start, stop)
            start = stop
        yield slice(start, None)

    def _top(self, counter: t.Dict[Category, Count]) -> Category:
        '''Return the most frequent key of ``counter``.

        Ties go to the key inserted first, because ``max`` returns the first
        maximal element it meets. Raises ``ValueError`` if ``counter`` is
        empty.
        '''
        return max(counter, key=counter.__getitem__)
if __name__ == '__main__':
    # Demo: segment a synthetic category sequence and scatter-plot each
    # segment as its own series. matplotlib is only needed for this demo.
    import matplotlib.pyplot as plt

    # x items per block, y categories, z = probability an item keeps its
    # block's category (see Group.random); n = threshold passed to slices().
    x, y, z, n = 100, 10, 0.7, 5
    cats = Group.random(x, y, z)
    xs = list(range(len(cats)))
    group = Group(cats).copy()
    # figure
    fig, ax = plt.subplots(1, 1, figsize=(12, 9), dpi=144)
    for idx in group.slices(threshold=n):
        # The final slice is open-ended (stop is None). Test for None
        # explicitly rather than by truthiness, since 0 is falsy.
        stop = len(xs) if idx.stop is None else idx.stop
        label = f'{idx.start:03} - {stop - 1:03}'
        ax.scatter(xs[idx], cats[idx], s=1, label=label)
    ax.set_xlabel('Index')
    ax.set_ylabel('Category')
    ax.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')
    fig.savefig('image.jpg', bbox_inches='tight')
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.