#!/usr/bin/env python
"""
test_range_finder.py: Tests the functions defined in rangeFinder.py.
__author__ = "pankajrsingla"
"""
import os
import sys
import matplotlib
import pandas as pd
import warnings
warnings.filterwarnings("ignore")
os.environ['USE_PYGEOS'] = '0'
import geopandas as gpd
import folium
import networkx as nx
import pytest
# Add the parent directory to sys.path
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
# Import everything from rangeFinder.py
from rangeFinder import *
[docs]
def test_get_colors():
"""
Test the 'get_colors' function to ensure it generates color lists correctly.
"""
# Define a list of predefined matplotlib colors
mcolors = ['#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', \
'#7f7f7f', '#bcbd22', '#17becf', '#1f77b4', '#ff7f0e']
# Test getting a subset of colors (5 colors)
colors_subset = get_colors(5)
assert colors_subset == mcolors[:5]
# Test getting all available colors (10 colors)
colors_all = get_colors(10)
assert colors_all == mcolors
# Test getting more colors than available (12 colors)
colors_repeat = get_colors(12)
assert colors_repeat == mcolors + mcolors[:2]
[docs]
def test_plot_initialization():
"""
Test the initialization of the Plot class
"""
# Create a Figure and Axes object for testing
fig = matplotlib.pyplot.figure()
ax = fig.add_subplot(111)
# Initialize a Plot object with the created Figure and Axes
plot = Plot(fig=fig, ax=ax)
# Check if the fig and ax attributes are set correctly
assert plot.fig is fig
assert plot.ax is ax
[docs]
def test_origin_initialization():
"""
Test case to check the initialization of the Origin class
"""
# Initialize an Origin object with an ID and coordinates
origin = Origin(id=1, lat=48.152519, long=11.582104)
# Check if the attributes are set correctly
assert origin.id == 1
assert origin.lat == 48.152519
assert origin.long == 11.582104
[docs]
def test_point_initialization():
"""
Test case to check the initialization of the Point class
"""
# Initialize a Point object with latitude, longitude, length, and mode
lat = 123.456
long = 23.45
length = 789
mode = "walk"
point_type = "hydrant"
point = Point(lat=lat, long=long, length=length,
mode=mode, point_type=point_type)
# Check if the attributes are set correctly
assert point.lat == lat
assert point.long == long
assert point.length == length
assert point.mode == mode
assert point.point_type == point_type
assert point.graph is None
assert point.origin is None
assert point.gdf is None
assert point.plot is None
assert point.points_in_range == []
assert point.interactive is None
assert point.debug is False
[docs]
def test_point_is_in_state():
"""
Test the 'is_in_state' function to check if a point lies in a state.
"""
# Point is in the specified state(s)
point_munich = Point(lat=48.15251938407116, long=11.582103781931604)
assert point_munich.is_in_state("Bayern") is True
point_delhi = Point(28.63005551713948, 77.21671040205324)
assert point_delhi.is_in_state("Delhi") is True
# Point is not in the specified state
assert point_munich.is_in_state("Berlin") is False
[docs]
def test_calculate_elevation():
"""
Test the 'calculate_elevation' function in the Point class.
"""
# Initialize a RangeFinder object
point = Point(lat=40.78397967404166, long=-73.96314049062211) # New York
point.calculate_elevation()
# Check if elevation is correct
assert int(point.elevation) == 34
[docs]
def test_point_get_range_graph():
"""
Test the 'get_range_graph' function in the Point class.
"""
# Initialize a Point object for testing
point = Point(lat=48.152519, long=11.582104, length=500, mode="drive")
# Check if valid graph has been created
point.get_range_graph()
assert point.graph is not None
# Invalid graph creation should raise an exception
# Note: Currently the get_range_graph function raises the exception.
# Uncomment this block when the exception has been gracefully handled.
# lat_invalid = 1000.0 # Invalid latitude value
# point_invalid = Point(lat=lat_invalid, long=11.582104, length=500, mode="drive")
# with pytest.raises(Exception) as exc_info:
# point_invalid.get_range_graph()
# assert "No valid network graph can be created" in str(exc_info.value)
[docs]
def test_point_get_points_in_range():
"""
Test the 'get_points_in_range' function in the Point class.
"""
# Initialize a Point object for testing.
length = 1000
point = Point(lat=48.152519, long=11.582104, length=length,
mode="drive", point_type="fire")
point.get_range_graph()
point.get_origin()
# Calculate points within range for the given Point object.
point.get_points_in_range()
# Check if number of graph nodes is same as number of points in range.
assert len(point.points_in_range) == len(point.graph.nodes())
# Check if each node has information in the correct format.
for node_data in point.points_in_range:
assert isinstance(node_data, list)
assert len(node_data) == 4
# Get distances from the points_in_range list.
distances = [node_data[3] for node_data in point.points_in_range]
# Assert that the minimum distance is 0 (origin node to itself).
assert min(distances) == 0
# The maximum distance should be less than or equal to the specified length.
assert max(distances) <= length
[docs]
def test_point_get_origin():
"""
Test the 'get_origin' function in the Point class.
"""
# Initialize a Point object for testing
lat = 48.152519
long = 11.582104
point = Point(lat=lat, long=long, length=400, mode="walk")
point.get_range_graph()
# Get the origin for the given Point object
point.get_origin()
# Check if 'origin' attribute is an instance of Origin
assert isinstance(point.origin, Origin)
# Compute origin coordinates directly, and compare
origin_id = ox.nearest_nodes(point.graph, X=point.long, Y=point.lat,
return_dist=False)
origin = Origin(origin_id)
origin.lat = point.graph.nodes[origin_id]["y"]
origin.long = point.graph.nodes[origin_id]["x"]
# Check if 'lat' and 'long' attributes of origin are set correctly
assert point.origin.lat == pytest.approx(origin.lat, abs=1e-6)
assert point.origin.long == pytest.approx(origin.long, abs=1e-6)
[docs]
def test_point_get_gdf():
"""
Test the 'get_gdf' function in the Point class.
"""
point = Point(lat=48.15251938407116, long=11.582103781931604,
length=400, mode="drive", point_type="fire")
point.get_range_graph()
# Generate GeoDataFrame for the given Point object
point.get_gdf()
# Check if 'gdf' attribute is set, is a GeoDataFrame, and has some rows
assert point.gdf is not None
assert isinstance(point.gdf, gpd.GeoDataFrame)
assert len(point.gdf) > 0
[docs]
def test_point_plot_network_graph():
"""
Test the 'plot_network_graph' function in the Point class.
"""
# Initialize a Point object for testing
point = Point(lat=28.63005551713948, long=77.21671040205324,
length=500, mode="drive")
point.get_range_graph()
# Plot the network graph for the given Point object
point.plot_network_graph()
# Check if 'plot' attribute is set and is an instance of Plot
assert point.plot is not None
assert isinstance(point.plot, Plot)
# Check that the plot has valid fig and ax objects
assert isinstance(point.plot.fig, matplotlib.figure.Figure)
assert isinstance(point.plot.ax, matplotlib.axes._axes.Axes)
[docs]
def test_point_annotate_node_distances_from_centre():
"""
Test the 'annotate_node_distances_from_centre' function in the Point class.
"""
# Initialize a Point object for testing
point = Point(lat=48.152519, long=11.582104, length=480, mode="drive")
point.get_range_graph()
# Plot the network graph for the given Point object
point.plot_network_graph()
# Check if 'plot' attribute is set and is an instance of Plot
assert point.plot is not None
assert isinstance(point.plot, Plot)
# Annotate the node distances
point.annotate_node_distances_from_centre()
# Check if there are at least annotations on the plot
annotations = point.plot.ax.texts
assert len(annotations) > 0
for annotation in annotations:
assert isinstance(annotation, matplotlib.text.Annotation)
[docs]
def test_point_annotate_origin_coordinates():
"""
Test the 'annotate_origin_coordinates' function in the Point class.
"""
# Initialize a Point object for testing
point = Point(lat=48.152519, long=11.582104, length=450, mode="drive")
point.get_range_graph()
point.plot_network_graph()
# Annotate the origin point's coordinates on the plot
point.annotate_origin_coordinates()
# Check if there are annotations on the plot (at least some)
annotations = point.plot.ax.texts
assert len(annotations) > 0
# Find the annotation that corresponds to the origin coordinates
origin_annotation = None
for annotation in annotations:
if annotation.get_text() == f"({point.origin.long:.3f}, {point.origin.lat:.3f})":
origin_annotation = annotation
break
# Assert that the origin annotation is found and has the correct attributes
assert origin_annotation is not None
assert isinstance(origin_annotation, matplotlib.text.Annotation)
assert origin_annotation.xy == (point.origin.long, point.origin.lat)
assert origin_annotation.get_ha() == 'center'
assert origin_annotation.get_fontsize() == 10
assert origin_annotation.get_color() == "white"
[docs]
def test_point_annotate_street_names():
"""
Test the 'annotate_street_names' function in the Point class.
"""
# Initialize a Point object for testing
point = Point(lat=51.509552216027714, long=-0.1284403947353646, length=1000, mode="walk")
point.plot_network_graph()
# Annotate street names on the plot
point.annotate_street_names()
# Check if there are annotations on the plot (at least some)
annotations = point.plot.ax.texts
assert len(annotations) > 0
# If repeat_street_names is set, there should be more annotations.
point_copy = Point(lat=51.509552216027714, long=-0.1284403947353646, length=1000, mode="walk")
point_copy.plot_network_graph()
point_copy.annotate_street_names(repeat_street_names=True)
additional_annotations = point_copy.plot.ax.texts
assert len(additional_annotations) > len(annotations)
[docs]
def test_point_get_plot():
"""
Test the 'get_plot' function in the Point class.
"""
# Call the test cases for the functions invoked by get_plot
test_point_plot_network_graph()
test_point_annotate_node_distances_from_centre()
test_point_annotate_origin_coordinates()
test_point_annotate_street_names()
[docs]
def test_point_get_interactive_map():
"""
Test the 'get_interactive_map' function in the Point class.
"""
# Initialize a Point object for testing
point = Point(lat=-35.30273276420382, long=149.12567284494077, length=500, mode="drive")
# Test marker icons
point.get_interactive_map(default_style="OpenStreetMap", edge_color="red")
# Check if the map has a valid name
assert isinstance(point.interactive.get_name(), str) # Check for map name
assert "map_" in point.interactive.get_name() # Map name starts with "map_"
# Check that there is a tile layer and a marker
all_children_value_types = [type(value) for value in list(point.interactive._children.values())]
assert folium.raster_layers.TileLayer in all_children_value_types
assert folium.map.Marker in all_children_value_types
# Check if the generated HTML map contains the expected elements
interactive_html = point.interactive.get_root().render()
assert "html" in interactive_html
assert "head" in interactive_html
assert "body" in interactive_html
assert "leaflet" in interactive_html
# Test if tooltip contains correct information
point_text = f"Point: ({point.lat},{point.long}), Hose length: {int(point.length)}, Mode: {point.mode}"
assert point_text in point.interactive.get_root().render()
[docs]
def test_rangefinder_initialization():
"""
Test the constructor for the RangeFinder class.
"""
# Initialize a RangeFinder object
range_finder = RangeFinder()
# Check if the show_elevations attribute is False
assert range_finder.show_elevations == False
# Check if the 'points' attribute is an empty list
assert isinstance(range_finder.points, list)
assert len(range_finder.points) == 0
# Check if the 'merged_gdf' attribute is set to None
assert range_finder.merged_gdf is None
# Check if the 'plots' attribute is an empty list
assert isinstance(range_finder.plots, list)
assert len(range_finder.plots) == 0
# Check if the 'merged_interactive' attribute is set to None
assert range_finder.merged_interactive is None
[docs]
def test_rangefinder_add_points():
"""
Test the 'add_points' function in the RangeFinder class.
"""
# Initialize a RangeFinder object
rf = RangeFinder()
# Create a sample input DataFrame with 4 points
input_data = {
"latitude": [48.1525, 48.1536, 48.1547, 48.1558],
"longitude": [11.5821, 11.5832, 11.5843, 11.5854],
"hose_length": [800, 600, 500, 400],
"transportation_mode": ["drive", "drive_service", "bike", "walk"],
"point_type": ["fire", "hydrant", "hydrant", "fire"]
}
input_df = pd.DataFrame(input_data)
# Add points to the RangeFinder
rf.add_points(input_df)
# Check if points have been added to the RangeFinder
assert len(rf.points) == 4
# Check if the added points have the correct attributes
for point in rf.points:
assert isinstance(point, Point)
assert isinstance(point.graph, nx.MultiDiGraph)
assert isinstance(point.origin.id, int)
assert isinstance(point.gdf, gpd.GeoDataFrame)
[docs]
def test_rangefinder_get_plots():
"""
Test the 'get_plots' function in the RangeFinder class.
"""
# Initialize a RangeFinder object
range_finder = RangeFinder()
# Create sample points and add them to the RangeFinder
points_data = [
{"latitude": 52.5187, "longitude": 13.3780, "hose_length": 450, "transportation_mode": "walk"},
{"latitude": 52.5817, "longitude": 13.3890, "hose_length": 550, "transportation_mode": "bike"},
{"latitude": 52.5781, "longitude": 13.3970, "hose_length": 650, "transportation_mode": "drive_service"},
{"latitude": 52.5718, "longitude": 13.3980, "hose_length": 750, "transportation_mode": "drive"}
]
for point_data in points_data:
point = Point(point_data["latitude"], point_data["longitude"],
point_data["hose_length"], point_data["transportation_mode"])
range_finder.points.append(point)
# Generate plots for the points using the get_plots method
range_finder.get_plots()
# Check if plots have been added to the RangeFinder
assert len(range_finder.plots) == 4
# Check if the added plots are instances of the Plot class
for plot in range_finder.plots:
assert isinstance(plot, Plot)
# Check that the plot has valid fig and ax objects
assert isinstance(plot.fig, matplotlib.figure.Figure)
assert isinstance(plot.ax, matplotlib.axes._axes.Axes)
[docs]
def test_calculate_elevations():
"""
Test the 'calculate_elevations' function in the RangeFinder class.
"""
# Initialize a RangeFinder object
rf = RangeFinder()
# Create a sample input DataFrame with 4 points
input_data = {
"latitude": [48.15251938407116, 28.63005551713948],
"longitude": [11.582103781931604, 77.21671040205324],
"hose_length": [800, 600],
"transportation_mode": ["drive", "drive_service"],
"point_type": ["fire", "hydrant"]
}
input_df = pd.DataFrame(input_data)
# Add points to the RangeFinder
rf.add_points(input_df)
rf.show_elevations = True
rf.calculate_elevations()
# Check if elevations were calculated and stored correctly
assert int(rf.points[0].elevation) == 517
assert int(rf.points[1].elevation) == 218
[docs]
def test_rangefinder_add_edge_colors():
"""
Test the 'add_edge_colors' function in the RangeFinder class.
"""
# Initialize a RangeFinder object
rf = RangeFinder()
# Create a sample input DataFrame with 4 points
input_data = {
"latitude": [28.63005551713948, 40.78397967404166],
"longitude": [77.21671040205324, -73.96314049062211],
"hose_length": [300, 400],
"transportation_mode": ["walk", "bike"],
"point_type": ["fire", "fire"]
}
input_df = pd.DataFrame(input_data)
# Add points to the RangeFinder
rf.add_points(input_df)
rf.merged_gdf = pd.concat([point.gdf for point in rf.points],
ignore_index=True)
assert rf.merged_interactive is None
rf.add_edge_colors()
# Interactive map has been generated
assert rf.merged_interactive is not None
# All element types in the interactive map
all_children_vals = [str(value) for value in list(rf.merged_interactive._children.values())]
# At this stage, the map should only have a TileLayer two FeatureGroups fr the two graphs
assert len(all_children_vals) == 3 #3 because the two overlap
assert "folium.raster_layers.TileLayer" in all_children_vals[0]
assert "folium.map.FeatureGroup" in all_children_vals[1]
assert "folium.map.FeatureGroup" in all_children_vals[2]
[docs]
def test_rangefinder_add_markers_to_map():
"""
Test the 'add_markers_to_map' function in the RangeFinder class.
"""
# Initialize a RangeFinder object
rf = RangeFinder()
# Create a sample input DataFrame with 4 points
input_data = {
"latitude": [37.99231816755858, -35.30273276420382],
"longitude": [23.732845405626346, 149.12567284494077],
"hose_length": [500, 600],
"transportation_mode": ["bike", "drive"],
"point_type": ["hydrant", "fire"]
}
input_df = pd.DataFrame(input_data)
# Add points to the RangeFinder
rf.add_points(input_df)
rf.merged_gdf = pd.concat([point.gdf for point in rf.points],
ignore_index=True)
rf.add_edge_colors()
rf.add_markers_to_map()
# n_points can be different than the number of points we specified above
# in input_data, since not all points are guaranteed to have a valid graph.
n_points = len(rf.points)
# All element types in the interactive map
all_children_vals = [str(value) for value in list(rf.merged_interactive._children.values())]
# Check that the interactive map has n circles for the n origin points
assert sum("folium.vector_layers.Circle" in value for value in all_children_vals) == n_points
# Check that the map has n markers for the n points
assert sum("folium.map.Marker" in value for value in all_children_vals) == n_points
# For n points, there should be (n * n-1) // 2 lines in the map
assert sum("folium.vector_layers.PolyLine" in value for value in all_children_vals) == \
n_points * (n_points - 1) // 2
[docs]
def test_rangefinder_create_interactive_map():
# Initialize a RangeFinder object
rf = RangeFinder()
# Create a sample input DataFrame with 2 points
input_data = {
"latitude": [50.87409910796766, 50.87803070660611],
"longitude": [4.697961746722167, 4.681290298545937],
"hose_length": [600, 300],
"transportation_mode": ["drive", "walk"],
"point_type": ["fire", "hydrant"]
}
input_df = pd.DataFrame(input_data)
# Add points to the RangeFinder
rf.add_points(input_df)
# Test if the method returns a Folium Map object
rf.create_interactive_map()
interactive_map = rf.merged_interactive
assert isinstance(rf.merged_interactive, folium.Map)
# Test if merged GeoDataFrame is created
assert isinstance(rf.merged_gdf, gpd.GeoDataFrame)
# Test if unique colors are generated correctly
unique_colors = get_colors(len(rf.points))
assert len(unique_colors) == len(rf.points)
# Test if edge colors are generated correctly
edge_colors = ([unique_colors[i]] * len(point.gdf) for i, point in enumerate(rf.points))
edge_colors = [color for color_list in edge_colors for color in color_list]
assert len(edge_colors) == len(rf.merged_gdf)
# Check if the map has a valid name
assert isinstance(interactive_map.get_name(), str) # Check for map name
assert "map_" in interactive_map.get_name() # Map name starts with "map_"
# Check that there is a tile layer and a marker
all_children_value_types = [type(value) for value in list(interactive_map._children.values())]
assert folium.raster_layers.TileLayer in all_children_value_types
assert folium.map.Marker in all_children_value_types