ImageRescaler
class ImageRescaler(Preprocessor):
    """Preprocessor that resizes every image in a dataset to a fixed size.

    Each row of ``dataset.data`` must provide ``"title"``, ``"data"`` (the
    pixel values) and ``"target"`` columns. The pixels in ``"data"`` are
    resized with PIL and optionally flattened to a 1-D array.

    Args:
        target_size: Output size handed to ``PIL.Image.Image.resize``, which
            interprets it as ``(width, height)``.
        original_size: ``(rows, cols)`` used to un-flatten 1-D pixel arrays
            before resizing.
        flatten: If True, resized images are returned as 1-D arrays.
    """

    def __init__(self, target_size=(32, 32), original_size=(512, 512), flatten=False):
        super().__init__()
        self.target_size = target_size
        self.original_size = original_size
        self.flatten = flatten

    def resize_image(self, image: np.ndarray) -> np.ndarray:
        """Resize a single image to ``self.target_size``.

        Args:
            image: Pixel data as an array or array-like (e.g. nested list).
                A 1-D input of length ``original_size[0] * original_size[1]``
                is treated as a flattened grayscale image; any other 1-D input
                is reshaped to ``(original_size[0], original_size[1], 3)``.

        Returns:
            The resized image as a numpy array, flattened to 1-D when
            ``self.flatten`` is set.

        Raises:
            ValueError: If a 1-D input cannot be reshaped to either layout.
        """
        # Plain lists (e.g. from CSV/JSON-loaded frames) have no ``ndim``.
        image = np.asarray(image)
        if image.ndim == 1:
            rows, cols = self.original_size[0], self.original_size[1]
            if image.size == rows * cols:
                # Flattened single-channel image.
                image = image.reshape((rows, cols))
            else:
                # Flattened RGB image; reshape raises if the length is wrong.
                image = image.reshape((rows, cols, 3))
        # NOTE(review): astype(np.uint8) truncates floats and wraps values
        # outside 0-255, so this assumes pixels are already in that range.
        pil_image = Image.fromarray(image.astype(np.uint8))
        resized_array = np.array(pil_image.resize(self.target_size))
        if self.flatten:
            # Works for both RGB (3-D) and grayscale (2-D) results.
            resized_array = resized_array.reshape(-1)
        return resized_array

    def resize_all_images(self, dataset: Dataset) -> pd.DataFrame:
        """Resize every image in ``dataset.data`` and return a new DataFrame.

        Args:
            dataset: Object whose ``data`` DataFrame has ``title``, ``data``
                and ``target`` columns, and which exposes ``label_numeric``.

        Returns:
            DataFrame with columns ``title`` (pandas string dtype), ``data``
            (resized arrays) and ``target``. ``target`` is cast to string
            dtype unless ``dataset.label_numeric`` is truthy. An empty input
            yields an empty DataFrame with the same columns.
        """
        records = [
            {
                "title": row["title"],
                "data": self.resize_image(row["data"]),
                "target": row["target"],
            }
            for _, row in dataset.data.iterrows()
        ]
        # Explicit columns keep the schema when there are no rows; without
        # them the casts below would raise KeyError on an empty frame.
        dataframe = pd.DataFrame(records, columns=["title", "data", "target"])
        dataframe["title"] = dataframe["title"].astype("string")
        if not dataset.label_numeric:
            dataframe["target"] = dataframe["target"].astype("string")
        return dataframe

    def apply(self, dataset: Dataset) -> Dataset:
        """Replace ``dataset.data`` with its resized version.

        The dataset is modified in place and also returned for chaining.
        """
        dataset.data = self.resize_all_images(dataset)
        return dataset
Last updated