Skip to content

Neural network prediction tasks

UNet prediction task

plantseg.tasks.prediction_tasks.unet_prediction_task(image: PlantSegImage, model_name: str | None, model_id: str | None, suffix: str = '_prediction', patch: tuple[int, int, int] | None = None, patch_halo: tuple[int, int, int] | None = None, single_batch_mode: bool = True, device: str = 'cuda', model_update: bool = False, disable_tqdm: bool = False, config_path: Path | None = None, model_weights_path: Path | None = None) -> list[PlantSegImage]

Apply a trained U-Net model to a PlantSegImage object.

Parameters:

  • image (PlantSegImage) –

    input image object

  • model_name (str | None) –

    the name of the model to use

  • model_id (str | None) –

    the ID of the model to use

  • suffix (str, default: '_prediction' ) –

    suffix to append to the new image name

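  • patch (tuple[int, int, int] | None, default: None ) –

    patch size for prediction

  • patch_halo (tuple[int, int, int] | None, default: None ) –

    halo around each patch, forwarded to unet_prediction

  • single_batch_mode (bool, default: True ) –

    whether to use a single batch for prediction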

  • device (str, default: 'cuda' ) –

    the computation device ('cpu', 'cuda', etc.)

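  • model_update (bool, default: False ) –

    whether to update the model to the latest version

  • disable_tqdm (bool, default: False ) –

    whether to disable the tqdm progress bar

  • config_path (Path | None, default: None ) –

    optional model configuration file, forwarded to unet_prediction

  • model_weights_path (Path | None, default: None ) –

    optional model weights file, forwarded to unet_prediction

Returns:

  • list[PlantSegImage] –

    one prediction image per output channel of the network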

Source code in plantseg/tasks/prediction_tasks.py (lines 9–70):
@task_tracker
def unet_prediction_task(
    image: PlantSegImage,
    model_name: str | None,
    model_id: str | None,
    suffix: str = "_prediction",
    patch: tuple[int, int, int] | None = None,
    patch_halo: tuple[int, int, int] | None = None,
    single_batch_mode: bool = True,
    device: str = "cuda",
    model_update: bool = False,
    disable_tqdm: bool = False,
    config_path: Path | None = None,
    model_weights_path: Path | None = None,
) -> list[PlantSegImage]:
    """
    Apply a trained U-Net model to a PlantSegImage object.

    The prediction is computed by `unet_prediction` and split along its first
    (channel) axis; every channel becomes a new PREDICTION image in the layout
    of the input image.

    Args:
        image (PlantSegImage): input image object
        model_name (str | None): the name of the model to use
        model_id (str | None): the ID of the model to use
        suffix (str): suffix appended to the input image name to build the
            output names. A "_" separator is inserted only when the suffix
            does not already start with one, so the default "_prediction"
            yields "<image.name>_prediction_<i>".
        patch (tuple[int, int, int] | None): patch size for prediction
        patch_halo (tuple[int, int, int] | None): halo around each patch,
            forwarded to `unet_prediction`
        single_batch_mode (bool): whether to use a single batch for prediction
        device (str): the computation device ('cpu', 'cuda', etc.)
        model_update (bool): whether to update the model to the latest version
        disable_tqdm (bool): whether to disable the tqdm progress bar,
            forwarded to `unet_prediction`
        config_path (Path | None): optional model configuration file,
            forwarded to `unet_prediction`
        model_weights_path (Path | None): optional model weights file,
            forwarded to `unet_prediction`

    Returns:
        list[PlantSegImage]: one prediction image per output channel, named
            "<image.name><sep><suffix>_<channel index>".
    """
    data = image.get_data()
    input_layout = image.image_layout

    pmaps = unet_prediction(
        raw=data,
        input_layout=input_layout.value,
        model_name=model_name,
        model_id=model_id,
        patch=patch,
        patch_halo=patch_halo,
        single_batch_mode=single_batch_mode,
        device=device,
        model_update=model_update,
        disable_tqdm=disable_tqdm,
        config_path=config_path,
        model_weights_path=model_weights_path,
    )
    # The loop below iterates channels and treats each one as ZYX, so the
    # prediction must be channel-first 4D (CZYX).
    assert pmaps.ndim == 4, f"Expected 4D CZYX prediction, got {pmaps.ndim}D"

    # Avoid a doubled separator (e.g. "img__prediction_0") for the default
    # suffix, which already starts with "_".
    separator = "" if suffix.startswith("_") else "_"
    base_name = f"{image.name}{separator}{suffix}"
    output_layout = input_layout.value

    new_images = []
    for i, pmap in enumerate(pmaps):
        # Each channel is ZYX; convert it back to the input image's layout.
        pmap = fix_layout(pmap, input_layout=ImageLayout.ZYX.value, output_layout=output_layout)
        new_images.append(
            image.derive_new(
                pmap,
                name=f"{base_name}_{i}",
                semantic_type=SemanticType.PREDICTION,
                image_layout=input_layout,
            )
        )

    return new_images