import io
import os
import sys
import time
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, as_completed

import geopandas as gpd
import laspy
import networkx as nx
import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import from_origin
from scipy.spatial import cKDTree
from shapely.geometry import LineString, Point
from tqdm import tqdm

# --- Acceleration Libraries ---
import cupy as cp
import cupyx.scipy.ndimage as ndimage
from skimage.morphology import skeletonize, disk

# Ensure UTF-8 output
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')


# -------------------------------------------------------------------------
# 1. VECTORIZED GRAPH BUILDER
# -------------------------------------------------------------------------
def skeleton_to_graph_vectorized(skeleton, transform, resolution):
    """Build an undirected 8-connected pixel graph from a binary skeleton.

    Every non-zero cell of ``skeleton`` becomes a node numbered 0..N-1 in
    the order returned by ``np.where`` and carrying its pixel position in
    the node attributes ``r`` (row) and ``c`` (column). Two nodes are joined
    when their pixels touch orthogonally or diagonally; the edge ``weight``
    is the centre-to-centre distance, i.e. ``resolution`` for orthogonal
    steps and ``resolution * sqrt(2)`` for diagonal ones.

    Args:
        skeleton: 2-D array; non-zero cells are skeleton pixels.
        transform: Not used by this function; accepted so existing callers
            keep working.
        resolution: Pixel size, used to scale the edge weights.

    Returns:
        networkx.Graph: The pixel graph. Node ids, ``r``/``c`` attributes and
        weights are native Python numbers.
    """
    rows, cols = np.where(skeleton)
    h, w = skeleton.shape

    # Lookup raster: node id at every skeleton pixel, -1 everywhere else.
    node_ids = np.arange(len(rows))
    id_grid = np.full((h, w), -1, dtype=np.int32)
    id_grid[rows, cols] = node_ids

    # Only the "forward" half of the 8-neighbourhood is scanned, so each
    # undirected neighbour pair is discovered exactly once.
    diagonal = float(np.sqrt(2.0))
    shifts = ((0, 1, 1.0), (1, 0, 1.0), (1, 1, diagonal), (1, -1, diagonal))

    edge_chunks = []
    weight_chunks = []
    for dr, dc, step in shifts:
        n_rows = rows + dr
        n_cols = cols + dc
        # Discard neighbours that fall outside the raster before indexing.
        inside = (n_rows >= 0) & (n_rows < h) & (n_cols >= 0) & (n_cols < w)
        src_ids = node_ids[inside]
        tgt_ids = id_grid[n_rows[inside], n_cols[inside]]
        hit = tgt_ids != -1
        if np.any(hit):
            edge_chunks.append(np.column_stack((src_ids[hit], tgt_ids[hit])))
            weight_chunks.append(np.full(int(hit.sum()), resolution * step))

    G = nx.Graph()
    # tolist() converts to native ints in one pass, so no NumPy scalars end
    # up as node keys or attributes.
    G.add_nodes_from(
        (idx, {"r": r, "c": c})
        for idx, r, c in zip(node_ids.tolist(), rows.tolist(), cols.tolist())
    )

    if edge_chunks:
        all_edges = np.vstack(edge_chunks)
        all_weights = np.concatenate(weight_chunks)
        G.add_weighted_edges_from(
            zip(
                all_edges[:, 0].tolist(),
                all_edges[:, 1].tolist(),
                all_weights.tolist(),
            )
        )

    return G
# -------------------------------------------------------------------------
# 2. TOPOLOGY CLEANING & FILTERING
# -------------------------------------------------------------------------
def _merge_into_centroid(G, members):
    """Contract ``members`` into ``members[0]``, placed at their mean (r, c).

    All entries of ``members`` must be distinct nodes currently in ``G``.
    """
    target = members[0]
    mean_r = sum(G.nodes[n]['r'] for n in members) / len(members)
    mean_c = sum(G.nodes[n]['c'] for n in members) / len(members)
    G.nodes[target]['r'] = mean_r
    G.nodes[target]['c'] = mean_c
    for node in members[1:]:
        nx.contracted_nodes(G, target, node, self_loops=False, copy=False)


def simplify_skeleton_graph(G, min_component_size=5, height=0, width=0, thresh_px=0,
                            max_cycle_len=20, max_cycle_passes=5):
    """Clean a skeleton pixel graph in place and return it.

    Three steps are applied in order:

    1. Connected components with fewer than ``min_component_size`` nodes are
       removed, unless they contain a boundary node.
    2. Cycles shorter than ``max_cycle_len`` nodes (from the minimum cycle
       basis) that contain no boundary node are collapsed into one node at
       the mean position of their members. This is repeated for at most
       ``max_cycle_passes`` passes or until a pass merges nothing.
    3. Connected clusters of non-boundary nodes with degree > 2 are each
       collapsed into one node at the cluster's mean position, repeated
       until no cluster has more than one node.

    A boundary node is one whose ``r``/``c`` attribute lies within
    ``thresh_px`` pixels of the raster border described by ``height`` and
    ``width``; the set is computed once, before any modification. With the
    default ``height=0`` / ``width=0`` every node satisfies that test, so
    callers must pass the real raster size for steps 2 and 3 to do anything.

    Args:
        G: networkx.Graph whose nodes carry ``r`` and ``c`` attributes.
        min_component_size: Components smaller than this are dropped.
        height: Raster height in pixels.
        width: Raster width in pixels.
        thresh_px: Border protection distance in pixels.
        max_cycle_len: Cycles with at least this many nodes are left alone.
        max_cycle_passes: Upper bound on cycle-collapsing passes.

    Returns:
        The same graph object ``G``, modified in place. Merged nodes get
        float ``r``/``c`` values.
    """
    # Anchor any nodes within the pixel threshold from the boundary
    boundary_nodes = {
        n for n, data in G.nodes(data=True)
        if data['r'] <= thresh_px or data['r'] >= height - 1 - thresh_px
        or data['c'] <= thresh_px or data['c'] >= width - 1 - thresh_px
    }

    # Step 1: only delete noise if no part of it is within the boundary
    # threshold. The component list is materialised before G is mutated.
    for comp in list(nx.connected_components(G)):
        if len(comp) < min_component_size and boundary_nodes.isdisjoint(comp):
            G.remove_nodes_from(comp)

    # Step 2: collapse small loops.
    for _ in range(max_cycle_passes):
        try:
            cycles = nx.minimum_cycle_basis(G)
        except Exception:
            # Best effort: if the basis cannot be computed, skip loop removal
            # rather than fail the whole file.
            break

        merged_any = False
        for cycle in cycles:
            if len(cycle) >= max_cycle_len:
                continue
            if any(n in boundary_nodes for n in cycle):
                continue
            # The basis was computed before this pass started contracting, so
            # a cycle may reference nodes that an earlier merge in the same
            # pass already removed. Work only on the survivors; reading the
            # removed ones would raise KeyError.
            members = [n for n in cycle if n in G]
            if len(members) < 2:
                continue
            _merge_into_centroid(G, members)
            merged_any = True

        if not merged_any:
            break

    # Step 3: fuse clusters of adjacent junction nodes. Every merge removes
    # at least one node, so the loop terminates.
    while True:
        junctions = [n for n, d in G.degree() if d > 2 and n not in boundary_nodes]
        if not junctions:
            break

        merged_any = False
        # Components are collected before contracting; they are disjoint, so
        # merging one cannot invalidate another.
        for comp in list(nx.connected_components(G.subgraph(junctions))):
            if len(comp) > 1:
                _merge_into_centroid(G, list(comp))
                merged_any = True

        if not merged_any:
            break

    return G
# -------------------------------------------------------------------------
# 3. GPU WORKER FOR EXISTING CENTERLINES
# -------------------------------------------------------------------------
def process_centerline_laz_worker(args):
    """Turn one centerline point cloud into a cleaned graph and write outputs.

    The points are rasterised on the GPU, morphologically closed, skeletonised
    on the CPU, converted to a pixel graph, simplified, and finally traced
    into 3-D polylines between critical nodes (degree != 2).

    Args:
        args: 9-tuple ``(input_path, out_laz, out_tif, out_shp_nodes,
            out_shp_edges, resolution, closing_radius, boundary_threshold,
            epsg)``. ``closing_radius`` is in pixels; ``boundary_threshold``
            is compared against coordinate distances to the grid extent.

    Returns:
        dict with keys ``success`` (bool), ``filename`` (basename of the
        input), ``message`` (reason when not successful) and ``duration``
        (seconds, set on every return path). Exceptions are caught and
        reported through ``message`` instead of being raised.
    """
    start_time = time.time()
    (input_path, out_laz, out_tif, out_shp_nodes, out_shp_edges,
     resolution, closing_radius, boundary_threshold, epsg) = args
    filename = os.path.basename(input_path)
    result = {"success": False, "filename": filename, "message": "", "duration": 0.0}

    try:
        cp.cuda.Device(0).use()
        las = laspy.read(input_path)
        if len(las.points) == 0:
            result["message"] = "Skipped: Empty file."
            return result

        points_xyz = np.vstack((las.x, las.y, las.z)).transpose()
        gpu_points = cp.asarray(points_xyz[:, :2])

        header_min_x, header_min_y = las.header.mins[0], las.header.mins[1]
        header_max_x, header_max_y = las.header.maxs[0], las.header.maxs[1]

        # Grid extent in whole cells. Width/height are taken from the integer
        # cell indices rather than from (max - min) / resolution, whose float
        # error could otherwise round up to an extra row or column.
        ix_min = np.floor(header_min_x / resolution)
        iy_min = np.floor(header_min_y / resolution)
        ix_max = np.ceil(header_max_x / resolution)
        iy_max = np.ceil(header_max_y / resolution)
        min_x = ix_min * resolution
        min_y = iy_min * resolution
        max_x = ix_max * resolution
        max_y = iy_max * resolution

        # At least one cell, so a degenerate extent does not yield an empty grid.
        width = max(1, int(ix_max - ix_min))
        height = max(1, int(iy_max - iy_min))

        idx_x = ((gpu_points[:, 0] - min_x) / resolution).astype(cp.int32)
        idx_y = ((max_y - gpu_points[:, 1]) / resolution).astype(cp.int32)
        idx_x = cp.clip(idx_x, 0, width - 1)
        idx_y = cp.clip(idx_y, 0, height - 1)

        gpu_grid = cp.zeros((height, width), dtype=cp.uint8)
        gpu_grid[idx_y, idx_x] = 1

        pad = closing_radius
        if pad > 0:
            # Edge padding keeps the closing from eroding at the tile border.
            gpu_grid_padded = cp.pad(gpu_grid, pad_width=pad, mode='edge')
            selem = cp.asarray(disk(closing_radius))
            gpu_closed_padded = ndimage.binary_closing(gpu_grid_padded, structure=selem)
            gpu_closed = gpu_closed_padded[pad:-pad, pad:-pad]
            del gpu_grid_padded, selem, gpu_closed_padded
        else:
            # Radius 0 means no closing; slicing with [0:-0] would return an
            # empty array, so the grid is used as-is.
            gpu_closed = gpu_grid.astype(cp.bool_)

        closed_cpu = cp.asnumpy(gpu_closed)

        # Drop the device arrays first: the pool can only release blocks that
        # are no longer referenced.
        del gpu_points, idx_x, idx_y, gpu_grid, gpu_closed
        cp.get_default_memory_pool().free_all_blocks()

        skeleton = skeletonize(closed_cpu, method='lee')
        if not np.any(skeleton):
            result["message"] = "Skipped: Could not build skeleton."
            return result

        transform = from_origin(float(min_x), float(max_y), resolution, resolution)

        # Calculate threshold in pixels to protect nodes during graph simplification
        thresh_px = int(np.ceil(boundary_threshold / resolution))

        G = skeleton_to_graph_vectorized(skeleton, transform, resolution)
        G = simplify_skeleton_graph(G, min_component_size=10, height=height,
                                    width=width, thresh_px=thresh_px)

        base_critical = [n for n, d in G.degree() if d != 2]
        critical_nodes = list(set(base_critical))
        # A graph made only of degree-2 nodes (closed loops) has no natural
        # start point; pick one node so the tracing below has somewhere to begin.
        if not critical_nodes and len(G.nodes) > 0:
            critical_nodes = [list(G.nodes)[0]]
        # Set copy for O(1) membership tests inside the path walk.
        critical_set = set(critical_nodes)

        tree = cKDTree(points_xyz[:, :2])

        def get_pt(r, c):
            x, y = rasterio.transform.xy(transform, r, c, offset='center')
            return x, y

        exact_node_coords = {}
        final_nodes_data = []

        # SNAP & FLAG LOGIC: Evaluate true geometric distance
        for n in critical_nodes:
            r, c = G.nodes[n]['r'], G.nodes[n]['c']
            x, y = get_pt(r, c)
            is_bridge = 0

            # Calculate absolute distance to mathematical borders
            dist_left = x - min_x
            dist_right = max_x - x
            dist_top = max_y - y
            dist_bottom = y - min_y

            # If any distance is within the user-defined threshold, it's a bridge!
            if min(dist_left, dist_right, dist_top, dist_bottom) <= boundary_threshold:
                is_bridge = 1
                # Snap exactly to the mathematical grid edge
                if dist_left <= boundary_threshold:
                    x = min_x
                elif dist_right <= boundary_threshold:
                    x = max_x
                if dist_top <= boundary_threshold:
                    y = max_y
                elif dist_bottom <= boundary_threshold:
                    y = min_y

            # Elevation is taken from the nearest input point in XY.
            _, idx = tree.query([x, y], k=1)
            z = points_xyz[idx, 2]
            exact_node_coords[n] = (x, y, z)
            final_nodes_data.append({
                'geometry': Point(x, y, z),
                'is_bridge': is_bridge,
                'node_id': int(n)
            })

        final_edges = []
        visited_edges = set()

        for start_node in critical_nodes:
            for neighbor in G.neighbors(start_node):
                edge_id = tuple(sorted((start_node, neighbor)))
                if edge_id in visited_edges:
                    continue
                visited_edges.add(edge_id)

                # Follow the chain of degree-2 nodes until the next critical node.
                path_nodes = [start_node, neighbor]
                curr, prev = neighbor, start_node
                while G.degree(curr) == 2 and curr not in critical_set:
                    nbrs = list(G.neighbors(curr))
                    if len(nbrs) == 1:
                        break
                    next_node = nbrs[0] if nbrs[0] != prev else nbrs[1]
                    visited_edges.add(tuple(sorted((curr, next_node))))
                    path_nodes.append(next_node)
                    prev, curr = curr, next_node

                line_coords = []
                for n_path in path_nodes:
                    if n_path in exact_node_coords:
                        # Critical nodes reuse their (possibly snapped) position.
                        line_coords.append(exact_node_coords[n_path])
                    else:
                        px, py = get_pt(G.nodes[n_path]['r'], G.nodes[n_path]['c'])
                        _, idx = tree.query([px, py], k=1)
                        pz = points_xyz[idx, 2]
                        line_coords.append((px, py, pz))

                # Drop consecutive vertices closer than 0.001 to each other.
                clean = []
                for coord in line_coords:
                    if not clean or np.linalg.norm(np.array(clean[-1]) - np.array(coord)) > 0.001:
                        clean.append(coord)

                if len(clean) >= 2:
                    final_edges.append(LineString(clean))

        # --- E. Save Outputs ---
        if final_nodes_data:
            gdf_nodes = gpd.GeoDataFrame(final_nodes_data, crs=f"EPSG:{epsg}")
            gdf_nodes.to_file(out_shp_nodes)

        if final_edges:
            gpd.GeoDataFrame(geometry=final_edges, crs=f"EPSG:{epsg}").to_file(out_shp_edges)

        with rasterio.open(
            out_tif, 'w', driver='GTiff', height=height, width=width, count=1,
            dtype=rasterio.uint8, crs=CRS.from_epsg(epsg), transform=transform,
            compress='lzw', nodata=0
        ) as dst:
            dst.write(skeleton.astype(rasterio.uint8), 1)

        all_pts = []
        for line in final_edges:
            all_pts.extend(line.coords)

        if all_pts:
            pts_arr = np.array(all_pts)
            hdr = laspy.LasHeader(point_format=3, version="1.2")
            hdr.scales = [0.001, 0.001, 0.001]
            hdr.offsets = [float(min_x), float(max_y), np.min(pts_arr[:, 2])]
            out = laspy.LasData(hdr)
            out.x, out.y, out.z = pts_arr[:, 0], pts_arr[:, 1], pts_arr[:, 2]
            out.write(out_laz)

        result["success"] = True
        return result

    except Exception as e:
        # Worker boundary: report the failure to the controller instead of
        # raising. The exception type is included because some exceptions
        # (e.g. KeyError) stringify to a bare value.
        result["message"] = f"Error: {type(e).__name__}: {e}"
        return result

    finally:
        # Runs on every exit path, so skipped files also report their time.
        result["duration"] = time.time() - start_time
# -------------------------------------------------------------------------
# 4. MAIN CONTROLLER
# -------------------------------------------------------------------------
def batch_process_centerlines(input_folder, output_folder, resolution, closing_radius,
                              boundary_threshold, epsg, max_workers=16):
    """Run the centerline worker over every .las/.laz file in a folder.

    Creates ``output_folder`` with the sub-folders ``tif``, ``laz``,
    ``shp_nodes`` and ``shp_edges``, dispatches one task per input file to a
    process pool, and writes one line per file to ``processing_log.txt``
    (the log is overwritten on each run).

    Args:
        input_folder: Folder scanned (non-recursively) for .las/.laz files.
        output_folder: Root folder for all outputs and the log.
        resolution: Passed through to the worker.
        closing_radius: Passed through to the worker (pixels).
        boundary_threshold: Passed through to the worker.
        epsg: EPSG code passed through to the worker.
        max_workers: Upper bound on concurrent worker processes; the actual
            count never exceeds the number of input files.

    Returns:
        None. If the folder contains no matching files the function logs
        that fact and returns without starting a process pool.
    """
    total_start = time.time()
    os.makedirs(output_folder, exist_ok=True)
    dirs = {k: os.path.join(output_folder, k) for k in ["tif", "laz", "shp_nodes", "shp_edges"]}
    for d in dirs.values():
        os.makedirs(d, exist_ok=True)

    log_file = os.path.join(output_folder, "processing_log.txt")
    # Sorted so the submission order is reproducible across runs.
    files = sorted(f for f in os.listdir(input_folder) if f.lower().endswith(('.laz', '.las')))
    log_header = f"Start: {time.ctime()}\n{'='*50}\n"

    if not files:
        # ProcessPoolExecutor rejects max_workers=0, so stop before creating it.
        with open(log_file, 'w', encoding='utf-8') as log:
            log.write(log_header)
            log.write("No .las/.laz files found; nothing to process.\n")
        print(f"No .las/.laz files found in {input_folder!r}; nothing to process.")
        return

    worker_count = max(1, min(max_workers, len(files)))

    print(f"\n{'='*70}")
    print("Hardware-Optimized Graph Extraction - SNAP THRESHOLD ALIGNED")
    print(f"Files: {len(files)} | Concurrent Workers: {worker_count}")
    print(f"Bridging Radius: {closing_radius} pixels")
    print(f"Boundary Snap Threshold: {boundary_threshold} meters")
    print(f"{'='*70}\n")

    tasks = []
    for file in files:
        base = os.path.splitext(file)[0]
        tasks.append((
            os.path.join(input_folder, file),
            os.path.join(dirs["laz"], f"{base}_graph_clean.laz"),
            os.path.join(dirs["tif"], f"{base}_graph.tif"),
            os.path.join(dirs["shp_nodes"], f"{base}_nodes.shp"),
            os.path.join(dirs["shp_edges"], f"{base}_edges.shp"),
            resolution, closing_radius, boundary_threshold, epsg
        ))

    success_count = 0
    fail_count = 0

    # UTF-8 so file names outside the platform default encoding can be logged.
    with open(log_file, 'w', encoding='utf-8') as log:
        log.write(log_header)
        log.flush()

        with ProcessPoolExecutor(max_workers=worker_count) as executor:
            future_to_file = {executor.submit(process_centerline_laz_worker, t): t[0] for t in tasks}

            with tqdm(total=len(files), unit="file") as pbar:
                for future in as_completed(future_to_file):
                    try:
                        res = future.result()
                    except Exception as exc:
                        # The worker process itself failed (e.g. it died);
                        # record it and keep going with the remaining files.
                        res = {
                            "success": False,
                            "filename": os.path.basename(future_to_file[future]),
                            "message": f"Error: worker process failed: {type(exc).__name__}: {exc}",
                            "duration": 0.0,
                        }

                    if res["success"]:
                        success_count += 1
                    else:
                        fail_count += 1

                    status = "SUCCESS" if res["success"] else "FAILED"
                    msg = f"[{status}] {res['filename']} ({res['duration']:.2f}s)"
                    if not res["success"]:
                        msg += f" - {res['message']}"

                    # Flushed per entry so the log is usable if the run is interrupted.
                    log.write(msg + "\n")
                    log.flush()

                    pbar.set_postfix(last=f"{res['filename'][:10]}..", t=f"{res['duration']:.1f}s")
                    pbar.update(1)

    print(f"\nDone in {time.time() - total_start:.2f}s. Success: {success_count}, Failed: {fail_count}")


if __name__ == "__main__":
    multiprocessing.set_start_method('spawn', force=True)

    IN_DIR = "1_Extracted_Centerlines_Merged"
    OUT_DIR = "2_Final_Graph_Outputs_V_1_1_3_5.1"
    EPSG = 28992
    RES = 0.1
    RADIUS = 4

    # Distance in METERS. Any endpoint within this range of a boundary gets snapped and flagged.
    BOUNDARY_THRESH = 0.2

    batch_process_centerlines(IN_DIR, OUT_DIR, RES, RADIUS, BOUNDARY_THRESH, EPSG)