select_bucket
The select_bucket
function is responsible for determining the appropriate resolution (size) for an input image, based on predefined resolutions and aspect ratios, as well as the maximum area allowed for the image. The function takes two arguments: image_width
and image_height
, which represent the dimensions of the input image.
Here’s a breakdown of the function’s logic:
It calculates the aspect ratio of the input image by dividing
image_width
byimage_height
.If
no_upscale
is False, the function attempts to find a predefined resolution that matches or closely approximates the aspect ratio of the input image. Here’s how it works:a. If the input image’s resolution is already in the set of predefined resolutions (
self.predefined_resos_set
), it leaves the resolution unchanged.b. If not, it calculates the difference between the input aspect ratio and each of the predefined aspect ratios (
self.predefined_aspect_ratios
). It then selects the predefined resolution with the smallest aspect ratio difference (predefined_bucket_id
).c. Based on the selected predefined resolution (
reso
), it determines the scale factor to resize the input image. If the input aspect ratio is greater than the predefined aspect ratio, it scales to match the height of the predefined resolution. Otherwise, it scales to match the width.d. The resized dimensions (
resized_size
) are calculated by multiplying the input dimensions by the scale factor and rounding them to the nearest integer.If
no_upscale
is True, the function performs only reduction (downscaling) of the image:a. If the input image’s area (
image_width * image_height
) exceeds the maximum allowed area (self.max_area
), it calculates the resized dimensions (resized_width
andresized_height
) by preserving the aspect ratio while ensuring the area does not exceedself.max_area
.b. It then rounds the resized dimensions to the nearest multiple of
self.reso_steps
(a predefined resolution step size). This is done by choosing the rounding that results in the aspect ratio closest to the original aspect ratio.c. The final resized dimensions (
resized_size
) are set to the rounded values.d. If the input image’s area does not exceed
self.max_area
, the function setsresized_size
to the original input dimensions, as no resizing is needed.e. The function then calculates the bucket resolution (
reso
) by finding the largest dimensions that are multiples ofself.reso_steps
and smaller than or equal to the resized dimensions.The function adds the selected resolution (
reso
) to a set of unique resolutions (self.add_if_new_reso(reso)
).Finally, it calculates the error between the aspect ratio of the selected resolution and the original aspect ratio (
ar_error
), and returns the tuple(reso, resized_size, ar_error)
.
The get_crop_ltrb
function, which is a static method, calculates the left, top, right, and bottom coordinates for cropping an image to fit a given bucket resolution. It takes two arguments: bucket_reso
(the target resolution) and image_size
(the original image size). The function determines whether to match the height or width of the bucket resolution based on the aspect ratios, calculates the resized dimensions, and then computes the crop coordinates based on the difference between the bucket resolution and the resized dimensions.
Translation
def select_bucket(self, image_width, image_height):
aspect_ratio = image_width / image_height
if not self.no_upscale:
# Perform enlargement and reduction
# Prioritize the same resolution because there might be the same aspect ratio
# (in case of preprocessing with no_upscale=True in fine tuning)
reso = (image_width, image_height)
if reso in self.predefined_resos_set:
pass
else:
ar_errors = self.predefined_aspect_ratios - aspect_ratio
predefined_bucket_id = np.abs(ar_errors).argmin() # The one with the least aspect ratio error other than the resolution
reso = self.predefined_resos[predefined_bucket_id]
ar_reso = reso[0] / reso[1]
if aspect_ratio > ar_reso: # If the width is longer, match the height
scale = reso[1] / image_height
else:
scale = reso[0] / image_width
resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5))
# logger.info(f"use predef, {image_width}, {image_height}, {reso}, {resized_size}")
else:
# Only perform reduction
if image_width * image_height > self.max_area:
# If the image is too large, decide the bucket on the premise of reducing it while keeping the aspect ratio
resized_width = math.sqrt(self.max_area * aspect_ratio)
resized_height = self.max_area / resized_width
assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
# Make the short side or long side after resizing a multiple of reso_steps: choose the one with less aspect ratio difference
# Same logic as the original bucketing
b_width_rounded = self.round_to_steps(resized_width)
b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
ar_width_rounded = b_width_rounded / b_height_in_wr
b_height_rounded = self.round_to_steps(resized_height)
b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
ar_height_rounded = b_width_in_hr / b_height_rounded
# logger.info(b_width_rounded, b_height_in_wr, ar_width_rounded)
# logger.info(b_width_in_hr, b_height_rounded, ar_height_rounded)
if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5))
else:
resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded)
# logger.info(resized_size)
else:
resized_size = (image_width, image_height) # No need to resize
# Make the size of the image less than the size of the bucket (crop without padding)
bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
# logger.info(f"use arbitrary {image_width}, {image_height}, {resized_size}, {bucket_width}, {bucket_height}")
reso = (bucket_width, bucket_height)
self.add_if_new_reso(reso)
ar_error = (reso[0] / reso[1]) - aspect_ratio
return reso, resized_size, ar_error
@staticmethod
def get_crop_ltrb(bucket_reso: Tuple[int, int], image_size: Tuple[int, int]):
# Calculate crop left/top according to the preprocessing of Stability AI. Crop right is calculated for flip augmentation.
bucket_ar = bucket_reso[0] / bucket_reso[1]
image_ar = image_size[0] / image_size[1]
if bucket_ar > image_ar:
# If the bucket is wider, match the height
resized_width = bucket_reso[1] * image_ar
resized_height = bucket_reso[1]
else:
resized_width = bucket_reso[0]
resized_height = bucket_reso[0] / image_ar
crop_left = (bucket_reso[0] - resized_width) // 2
crop_top = (bucket_reso[1] - resized_height) // 2
crop_right = crop_left + resized_width
crop_bottom = crop_top + resized_height
return crop_left, crop_top, crop_right, crop_bottom
class BucketBatchIndex(NamedTuple):
bucket_index: int
bucket_batch_size: int
batch_index: int