BerGeo/h2/util.py

203 lines
5.3 KiB
Python

import random
from collections import namedtuple, defaultdict
from enum import Enum, auto
from typing import Set
from math import cos, sin, sqrt, pi
import matplotlib.pyplot as plt
import numpy as np
# Immutable 2-D point with fields ``x`` and ``y`` (a tuple subclass).
Point = namedtuple('Point', ['x', 'y'])
# Immutable 2-D vector: a separate named type with the same two fields.
Vector = namedtuple('Vector', ['x', 'y'])
def gen_point(lower: float = 0, upper: float = 10) -> Point:
    """Return a point whose coordinates are each drawn uniformly between
    ``lower`` and ``upper``.

    x is drawn before y, so results are reproducible under a fixed
    ``random.seed``.
    """
    x = random.uniform(lower, upper)
    y = random.uniform(lower, upper)
    return Point(x, y)
def display(points: Set[Point], hull: Set[Point]):
    """Show every point in ``points`` as a scatter plot, with the points of
    ``hull`` additionally drawn as red circle markers, then call
    ``plt.show()``.
    """
    all_xs, all_ys = [p.x for p in points], [p.y for p in points]
    hull_xs, hull_ys = [p.x for p in hull], [p.y for p in hull]
    # Markers only ('ro', no connecting line), so the unordered hull set
    # does not need to be sorted around the hull first.
    plt.plot(hull_xs, hull_ys, 'ro')
    plt.scatter(all_xs, all_ys)
    plt.show()
def gen_circular_point(lower: float = 0, upper: float = 10, radius: float = 5) -> Point:
    """Return a point drawn uniformly from the disk of ``radius`` centred at
    the origin.

    The angle is uniform on [0, 2*pi). The distance from the centre is
    ``radius * sqrt(u)`` with ``u`` uniform on [0, 1); the square root gives
    equal density per unit area instead of clustering near the centre.

    ``lower`` and ``upper`` are accepted only for backward compatibility and
    no longer affect the result. Previously ``u`` was drawn from
    [lower, upper], which had two problems:

    - With the defaults (0, 10), points landed in a disk of radius
      ``radius * sqrt(10)`` instead of ``radius``.
    - A negative ``lower`` could make ``sqrt`` raise ``ValueError``.
    """
    angle = random.uniform(0, 2 * pi)
    r = radius * sqrt(random.random())
    return Point(r * cos(angle), r * sin(angle))
def gen_weird_point(lower : int = 0, upper: int = 10) -> Point:
    """Return a point on or above the parabola y = x**2.

    A value ``t`` is drawn uniformly between ``lower`` and ``upper``. The
    point's y coordinate is ``t**2``, and its x coordinate is drawn uniformly
    between ``-|t|`` and ``|t|``, so ``x**2 <= y`` always holds.
    """
    t = random.uniform(lower, upper)
    half_width = abs(t)
    return Point(random.uniform(-half_width, half_width), t ** 2)
def read_and_prep_data(filename):
    """Parse a results file into ``{name: [[points, time], ...]}``.

    The first line is treated as a header and skipped. Each remaining
    non-blank line must hold exactly three fields separated by a double tab
    (``"\\t\\t"``): name, points, and time.

    - The name is stripped of surrounding whitespace and used as the key.
    - ``points`` is kept as the raw string.
    - ``time`` is truncated to its first 8 characters and also stays a string.

    Blank lines, including the empty string produced by a trailing newline,
    are ignored.

    Returns:
        A ``defaultdict(list)``, so lookups of absent names yield ``[]``.

    Raises:
        ValueError: if a non-blank line does not contain exactly three fields.
    """
    data = defaultdict(list)
    # Close the file promptly instead of leaking the handle.
    with open(filename) as handle:
        lines = handle.read().split('\n')
    for line in lines[1:]:
        if not line.strip():
            # A trailing newline yields "", which would break the unpack below.
            continue
        name, points, time_str = line.split("\t\t")
        data[name.strip()].append([points, time_str[:8]])
    return data
def gen_graph(data):
    """Plot one curve per algorithm name and show the figure.

    ``data`` maps a name to a list of ``[x, y]`` pairs, such as the
    ``[points, time]`` pairs returned by ``read_and_prep_data``. The series
    'graham', 'quick', 'mbch', 'mbch2' and 'gift' are drawn in that order
    and labelled in an upper-left legend. A missing name is plotted as an
    empty series, and ``data`` is not modified.
    """
    names = ['graham', 'quick', 'mbch', 'mbch2', 'gift']
    for name in names:
        rows = data.get(name, [])
        # read_and_prep_data yields strings. Plotting them as-is gives
        # matplotlib categorical axes ordered by first appearance rather than
        # numeric ones, so convert explicitly.
        xs = [float(row[0]) for row in rows]
        ys = [float(row[1]) for row in rows]
        plt.plot(xs, ys)
    plt.legend(names, loc='upper left')
    plt.show()
def gen_triangular_point(left : Point, right : Point, top : Point):
    """Return a point drawn uniformly from the triangle (left, right, top).

    Uses the square-root barycentric construction. With ``u`` and ``v``
    uniform on [0, 1], the weights ``(1 - sqrt(u), sqrt(u) * (1 - v),
    sqrt(u) * v)`` are non-negative, sum to one, and give equal density per
    unit area.
    """
    u = random.uniform(0, 1)
    v = random.uniform(0, 1)
    root_u = sqrt(u)
    w_left = 1 - root_u
    w_right = root_u * (1 - v)
    w_top = root_u * v
    px = w_left * left.x + w_right * right.x + w_top * top.x
    py = w_left * left.y + w_right * right.y + w_top * top.y
    return Point(px, py)
def display_line_only(points: Set[Point], slope: float, intercept: float, line_points: Set[Point]):
    """Scatter ``points``, mark each of ``line_points`` with a green circle,
    overlay the dashed line ``y = intercept + slope * x``, and call
    ``plt.show()``.
    """
    plt.scatter([p.x for p in points], [p.y for p in points])
    # Read the x-limits right after the scatter, before anything else is
    # drawn, so the line spans the range autoscaled to the scattered data.
    x_range = np.array(plt.gca().get_xlim())
    y_range = intercept + slope * x_range
    for p in line_points:
        plt.plot(p.x, p.y, 'go')
    plt.plot(x_range, y_range, '--')
    plt.show()
class Side(Enum):
    """Position classification with three members: ON, ABOVE, BELOW.

    Presumably relative to a line; the classifying logic is not in this
    class.
    """
    # Explicit values equal to what auto() assigns (starting at 1).
    ON = 1
    ABOVE = 2
    BELOW = 3
def stacked_bar(ax, data, series_labels, category_labels=None,
                show_values=True, value_format="{}", y_label=None,
                grid=False, reverse=False):
    """
    Plots a stacked bar chart on ``ax`` with the data and labels provided
    (adapted from https://stackoverflow.com/a/50205834).

    Keyword arguments:
    data -- 2-dimensional numpy array or nested list
            containing data for each series in rows
            (one row per series, one column per category)
    series_labels -- list of series labels (these appear in
                     the legend); needs at least one per row
    category_labels -- list of category labels (these appear
                       on the x-axis)
    show_values -- If True then numeric value labels will
                   be shown on each non-zero bar segment
    value_format -- Format string for numeric value labels
                    (default is "{}")
    y_label -- Label for y-axis (str)
    grid -- If True display grid
    reverse -- If True reverse the order of the categories
               (left-to-right or right-to-left); category_labels,
               when given, are reversed to match

    All drawing goes through ``ax``; the figure is not shown here.
    """
    n_categories = len(data[0])
    positions = list(range(n_categories))
    values = np.array(data)
    if reverse:
        # Flip the column (category) order and keep the labels aligned.
        values = np.flip(values, axis=1)
        if category_labels is not None:
            # list() rather than a bare reversed() iterator, and guarded so
            # reverse=True without labels no longer raises TypeError.
            category_labels = list(reversed(category_labels))

    containers = []
    bottoms = np.zeros(n_categories)
    for i, row in enumerate(values):
        # Each series is stacked on top of the running total of the
        # previous ones.
        containers.append(ax.bar(positions, row, bottom=bottoms,
                                 label=series_labels[i]))
        bottoms += row

    # Use ax methods rather than plt.* so labels land on ``ax`` even when it
    # is not the pyplot "current" axes (previously plt.ylabel could label a
    # different subplot when no category_labels were passed).
    if category_labels:
        ax.set_xticks(positions)
        ax.set_xticklabels(category_labels)
    if y_label:
        ax.set_ylabel(y_label)

    # Reverse the legend so its top-to-bottom order matches the stacking
    # order of the bars, last series on top
    # (https://stackoverflow.com/a/34576778).
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[::-1], labels[::-1])

    if grid:
        ax.grid()

    if show_values:
        for container in containers:
            for bar in container:
                w, h = bar.get_width(), bar.get_height()
                if h != 0:
                    ax.text(bar.get_x() + w/2, bar.get_y() + h/2,
                            value_format.format(h), ha="center",
                            va="center")