def weighted_histogram_filter_single_pixel(idx, img, kernel, kernel_center, percentile_min, percentile_max, min_width):
    """
    Apply the weighted histogram filter to a single pixel.

    The kernel is centered on "idx" and clipped to the image bounds, so
    pixels near an edge only see the in-bounds part of the kernel
    (no padding values ever enter the histogram).

    The covered image values are sorted and stacked on top of each other,
    each one as tall as its kernel weight. The result is the average of the
    values overlapping the slice [percentile_min, percentile_max] of that
    stack, weighted by the size of the overlap.

    This is a module-level function (not a closure) so that it can be
    shipped to worker processes.

    Args:
        idx (tuple or nparray): Index of the pixel in "img".
        img (nparray): The image to sample.
        kernel (nparray): Per-offset weights, same number of dimensions as "img".
        kernel_center (nparray): Index of the kernel element that is aligned with "idx".
        percentile_min (float): Lower bound of the slice, as a fraction of the total weight.
        percentile_max (float): Upper bound of the slice, as a fraction of the total weight.
        min_width (float): Minimum height of the slice, in weight units.

    Returns:
        The filtered value for the pixel, or 0 if nothing overlaps the slice.
    """
    idx = np.atleast_1d(np.asarray(idx))
    center = np.asarray(kernel_center)

    # Footprint of the kernel in image coordinates, clipped to the image.
    lo = np.maximum(0, idx - center)
    hi = np.minimum(np.array(img.shape), idx - center + np.array(kernel.shape))

    # Matching slices of the image and of the kernel. Both are flattened in
    # C order, so values[i] and weights[i] refer to the same offset.
    img_window = tuple(slice(int(a), int(b)) for a, b in zip(lo, hi))
    kernel_window = tuple(slice(int(a), int(b)) for a, b in zip(lo - idx + center, hi - idx + center))
    values = np.asarray(img[img_window]).ravel()
    weights = np.asarray(kernel[kernel_window]).ravel()

    if values.size == 0:
        return 0.0

    # Sort the samples by value; a stable sort keeps the result deterministic.
    order = np.argsort(values, kind="stable")
    values = values[order]
    weights = weights[order]

    # Each sample occupies [stack_bottom, stack_top) in the stack.
    stack_top = np.cumsum(weights)
    stack_bottom = np.concatenate(([0], stack_top[:-1]))
    sum_weights = stack_top[-1]

    # The slice of the stack ("window") to average across.
    window_min = sum_weights * percentile_min
    window_max = sum_weights * percentile_max

    # Ensure the window is at least "min_width" tall. When it would poke out
    # of the stack it is shifted back inside rather than shrunk.
    if window_max - window_min < min_width:
        window_center = (window_min + window_max) / 2
        window_min = window_center - min_width / 2
        window_max = window_center + min_width / 2

        if window_max > sum_weights:
            window_max = sum_weights
            window_min = sum_weights - min_width

        if window_min < 0:
            window_min = 0
            window_max = min_width

    # How much of each sample lies inside the window.
    overlap = np.maximum(0, np.minimum(window_max, stack_top) - np.maximum(window_min, stack_bottom))
    total_overlap = np.sum(overlap)
    if total_overlap <= 0:
        return 0.0

    return np.sum(values * overlap) / total_overlap


def _resolve_worker_count(n_jobs, n_tasks, auto_cap=6):
    """
    Translate an "n_jobs" request into a usable number of workers (>= 1).

    None, 0 and 1 mean serial. -1 means "automatic": one worker per CPU,
    capped at "auto_cap". Other negative values leave CPUs free
    (cpu_count + 1 + n_jobs). The result never exceeds "n_tasks".
    """
    import os

    available = os.cpu_count() or 1
    if not n_jobs:
        requested = 1
    elif n_jobs == -1:
        requested = min(available, auto_cap)
    elif n_jobs < 0:
        requested = available + 1 + n_jobs
    else:
        requested = n_jobs
    return max(1, min(requested, n_tasks))


def _weighted_histogram_filter_rows(row_start, row_stop, img, kernel, kernel_center, percentile_min, percentile_max, min_width):
    """
    Filter the rows [row_start, row_stop) of "img" and return them as a new array.
    """
    trailing_shape = img.shape[1:]
    out = np.empty((row_stop - row_start,) + trailing_shape, dtype=img.dtype)
    for offset, row in enumerate(range(row_start, row_stop)):
        for rest in np.ndindex(*trailing_shape):
            out[(offset, *rest)] = weighted_histogram_filter_single_pixel(
                (row, *rest), img, kernel, kernel_center, percentile_min, percentile_max, min_width
            )
    return out


def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0, n_jobs=-1):
    """
    Generalization convolution filter capable of applying
    weighted mean, median, maximum, and minimum filters
    parametrically using an arbitrary kernel.

    Every output pixel is computed by weighted_histogram_filter_single_pixel;
    see that function for the exact meaning of the percentile window.
    For example, with min_width=1.0 and a kernel of ones:
    percentile_min=0.0, percentile_max=1.0 averages the neighborhood,
    0.5/0.5 picks its median, 0.0/0.0 its minimum and 1.0/1.0 its maximum.

    Args:
        img (nparray): The image, a 2-D array of floats, to which the filter is applied.
        kernel (nparray): The kernel, an array of weights with as many dimensions as "img".
        kernel_center (nparray or int): The index of the kernel element aligned with the
            pixel being filtered. A single number is used for every axis.
        percentile_min (float): Lower bound of the histogram window, in [0, 1].
        percentile_max (float): Upper bound of the histogram window, in [0, 1].
        min_width (float): Minimum height of the histogram window, in weight units.
        n_jobs (int or None): Number of worker processes. None, 0 and 1 run serially;
            -1 picks one worker per CPU (at most 6); other negative values follow the
            joblib convention (cpu_count + 1 + n_jobs). Parallel execution needs joblib;
            if it cannot be imported the filter runs serially instead.

    Returns:
        (nparray): A filtered copy of the input image "img", a 2-D array of floats.
    """
    img = np.asarray(img)
    kernel = np.asarray(kernel)

    # Accept a scalar (or one-element) center and apply it to every axis.
    center = np.atleast_1d(np.asarray(kernel_center))
    if center.size == 1:
        center = np.repeat(center, kernel.ndim)

    if img.size == 0:
        return img.copy()

    # Split the image into contiguous bands of rows, one per worker.
    n_rows = img.shape[0]
    workers = _resolve_worker_count(n_jobs, n_rows)
    bounds = np.linspace(0, n_rows, workers + 1).astype(int)
    spans = [(int(a), int(b)) for a, b in zip(bounds[:-1], bounds[1:]) if b > a]
    shared = (img, kernel, center, percentile_min, percentile_max, min_width)

    bands = None
    if len(spans) > 1:
        try:
            from joblib import Parallel, delayed
        except ImportError:
            # joblib is optional: fall through to the serial path below.
            pass
        else:
            bands = Parallel(n_jobs=len(spans), backend="loky")(
                delayed(_weighted_histogram_filter_rows)(start, stop, *shared) for start, stop in spans
            )

    if bands is None:
        bands = [_weighted_histogram_filter_rows(start, stop, *shared) for start, stop in spans]

    # Combine results into the output image.
    img_out = np.empty_like(img)
    for (start, stop), band in zip(spans, bands):
        img_out[start:stop] = band

    return img_out