Skip to content

Commit

Permalink
keep fixing compilation issues
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Sep 30, 2024
1 parent 1938cd7 commit cba377b
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,12 @@ void main(String[] args) throws LoadModelException, RunModelException {
final ImgFactory< FloatType > imgFactory = new ArrayImgFactory<>( new FloatType() );
final Img< FloatType > img1 = imgFactory.create( 1, 1, 512, 512 );
Tensor<FloatType> inpTensor = Tensor.build("input0", "bcyx", img1);
List<Tensor<T>> inputs = new ArrayList<Tensor<T>>();
inputs.add((Tensor<T>) inpTensor);
List<Tensor<FloatType>> inputs = new ArrayList<Tensor<FloatType>>();
inputs.add(inpTensor);
final Img< FloatType > img2 = imgFactory.create( 1, 2, 512, 512 );
Tensor<FloatType> outTensor = Tensor.build("output0", "bcyx", img2);
List<Tensor<R>> outputs = new ArrayList<Tensor<R>>();
outputs.add((Tensor<R>) outTensor);
List<Tensor<FloatType>> outputs = new ArrayList<Tensor<FloatType>>();
outputs.add(outTensor);
bi.run(inputs, outputs);
System.out.print(DECODE_JSON_VAL);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ public static void main(String[] args) {
void loadAndRunModel(String modelFolder, ModelDescriptor descriptor) throws Exception {
Model model = Model.createBioimageioModel(modelFolder, ENGINES_DIR);
model.loadModel();
List<Tensor<T>> inputs = createInputs(descriptor);
List<Tensor<R>> outputs = createOutputs(descriptor);
List<Tensor<FloatType>> inputs = createInputs(descriptor);
List<Tensor<FloatType>> outputs = createOutputs(descriptor);
model.runModel(inputs, outputs);
for (Tensor<?> tt : outputs) {
for (Tensor<FloatType> tt : outputs) {
if (tt.isEmpty())
throw new Exception(descriptor.getName() + ": Output tensor is empty");
}
Expand All @@ -147,8 +147,8 @@ void loadAndRunModel(String modelFolder, ModelDescriptor descriptor) throws Exce
* file containing the information
* @return the input Tensor list
*/
private static <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createInputs(ModelDescriptor descriptor) {
List<Tensor<T>> inputs = new ArrayList<Tensor<T>>();
private static List<Tensor<FloatType>> createInputs(ModelDescriptor descriptor) {
List<Tensor<FloatType>> inputs = new ArrayList<Tensor<FloatType>>();
final ImgFactory< FloatType > imgFactory = new ArrayImgFactory<>( new FloatType() );

for ( TensorSpec it : descriptor.getInputTensors()) {
Expand All @@ -159,7 +159,7 @@ private static <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createInp
long[] imSize = LongStream.range(0, step.length)
.map(i -> min[(int) i] + step[(int) i]).toArray();
Tensor<FloatType> tt = Tensor.build(name, axesStr, imgFactory.create(imSize));
inputs.add((Tensor<T>) tt);
inputs.add(tt);
}
return inputs;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ void loadAndRunTf1() throws LoadEngineException, Exception {
// Create the input tensor with the nameand axes given by the rdf.yaml file
// and add it to the list of input tensors
Tensor<FloatType> inpTensor = Tensor.build("input", "byxc", img1);
List<Tensor<T>> inputs = new ArrayList<Tensor<T>>();
inputs.add((Tensor<T>) inpTensor);
List<Tensor<FloatType>> inputs = new ArrayList<Tensor<FloatType>>();
inputs.add(inpTensor);

// Create the output tensors defined in the rdf.yaml file with their corresponding
// name and axes and add them to the output list of tensors.
Expand All @@ -212,8 +212,8 @@ void loadAndRunTf1() throws LoadEngineException, Exception {
// defining the dimensions and data type
final Img< FloatType > img2 = imgFactory.create( 1, 512, 512, 33 );
Tensor<FloatType> outTensor = Tensor.build("output", "byxc", img2);
List<Tensor<R>> outputs = new ArrayList<Tensor<R>>();
outputs.add((Tensor<R>) outTensor);
List<Tensor<FloatType>> outputs = new ArrayList<Tensor<FloatType>>();
outputs.add(outTensor);

// Run the model on the input tensors. THe output tensors
// will be rewritten with the result of the execution
Expand Down
7 changes: 4 additions & 3 deletions src/main/java/io/bioimage/modelrunner/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;

/**
* Class that manages a Deep Learning model to load it and run it.
Expand Down Expand Up @@ -534,7 +535,7 @@ void runModel( List< Tensor < T > > inTensors, List< Tensor < R > > outTensors )
ArrayList<Tensor<FloatType>> inTensorsFloat = new ArrayList<Tensor<FloatType>>();
for (Tensor<T> tt : inTensors) {
if (tt.getData().getAt(0) instanceof FloatType)
inTensorsFloat.add((Tensor<FloatType>) tt);
inTensorsFloat.add(Cast.unchecked(tt));
else
inTensorsFloat.add(Tensor.createCopyOfTensorInWantedDataType( tt, new FloatType() ));
}
Expand Down Expand Up @@ -707,9 +708,9 @@ List<Tensor<T>> runTiling(List<Tensor<R>> inputTensors, TileMaker tiles, TilingC
public static <T extends NativeType<T> & RealType<T>> void main(String[] args) throws IOException, ModelSpecsException, LoadEngineException, RunModelException, LoadModelException {

String mm = "/home/carlos/git/JDLL/models/NucleiSegmentationBoundaryModel_17122023_143125";
Img<T> im = (Img<T>) ArrayImgs.floats(new long[] {1, 1, 512, 512});
Img<T> im = Cast.unchecked(ArrayImgs.floats(new long[] {1, 1, 512, 512}));
List<Tensor<T>> l = new ArrayList<Tensor<T>>();
l.add((Tensor<T>) Tensor.build("input0", "bcyx", im));
l.add(Tensor.build("input0", "bcyx", im));
Model model = createBioimageioModel(mm);
model.loadModel();
TileInfo tile = TileInfo.build(l.get(0).getName(), new long[] {1, 1, 512, 512},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,8 @@
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.type.numeric.integer.ShortType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.integer.UnsignedIntType;
import net.imglib2.type.numeric.integer.UnsignedShortType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import net.imglib2.view.IntervalView;
Expand Down Expand Up @@ -227,9 +220,21 @@ private static long[][] getAllCombinations(long[] arr){
return allPoints;
}

@SuppressWarnings("unchecked")
public < R extends RealType< R > & NativeType< R > >
void scaleLinear(RandomAccessibleInterval<R> rai, double gain, double offset) {

R type = Util.getTypeFromInterval(rai);
if (type instanceof IntegerType) {
LoopBuilder.setImages( rai )
.multiThreaded()
.forEachPixel( i -> i.setReal(Math.floor(i.getRealDouble() * gain + offset) ) );
} else {
LoopBuilder.setImages( rai )
.multiThreaded()
.forEachPixel( i -> i.setReal((i.getRealDouble() * gain + offset) ) );
}
/**
* TODO remove
if (rai.getAt(0) instanceof ByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<ByteType>) rai )
.multiThreaded()
Expand Down Expand Up @@ -269,5 +274,6 @@ void scaleLinear(RandomAccessibleInterval<R> rai, double gain, double offset) {
} else {
throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai));
}
*/
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,8 @@
import net.imglib2.img.basictypeaccess.array.FloatArray;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.type.numeric.integer.ShortType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.integer.UnsignedIntType;
import net.imglib2.type.numeric.integer.UnsignedShortType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import net.imglib2.view.IntervalView;
Expand Down Expand Up @@ -146,14 +139,21 @@ private < R extends RealType< R > & NativeType< R > > void globalScale( final Te
scaleRange(output.getData(), maxPercentileVal, minPercentileVal);
}

@SuppressWarnings("unchecked")
private < R extends RealType< R > & NativeType< R > >
double findPercentileValue(RandomAccessibleInterval<R> rai, double percentile) {
final IterableInterval<R> flatImage = Views.iterable(rai);
long flatSize = Arrays.stream(flatImage.dimensionsAsLongArray()).reduce(1, (a, b) -> a * b);
double[] flatArr = new double[(int) flatSize];

int count = 0;
final Cursor<R> cursor = (Cursor<R>) flatImage.cursor();
while ( cursor.hasNext() )
{
cursor.next();
flatArr[count ++] = cursor.get().getRealDouble();
}
/*
* TODO remove
if (rai.getAt(0) instanceof ByteType) {
final Cursor<ByteType> cursor = (Cursor<ByteType>) flatImage.cursor();
while ( cursor.hasNext() )
Expand Down Expand Up @@ -220,6 +220,7 @@ private < R extends RealType< R > & NativeType< R > > void globalScale( final Te
} else {
throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai));
}
*/
Arrays.sort(flatArr);

int percentilePos = (int) (flatSize * percentile);
Expand Down Expand Up @@ -306,10 +307,22 @@ public static void test2() {
System.out.print(true);
}

@SuppressWarnings("unchecked")
public < R extends RealType< R > & NativeType< R > >
void scaleRange(RandomAccessibleInterval<R> rai, double maxPercentileVal, double minPercentileVal) {
double diff = maxPercentileVal - minPercentileVal;

R type = Util.getTypeFromInterval(rai);
if (type instanceof IntegerType) {
LoopBuilder.setImages( rai )
.multiThreaded()
.forEachPixel( i -> i.setReal(Math.floor((i.getRealDouble() - minPercentileVal) / (diff + eps)) ) );
} else {
LoopBuilder.setImages( rai )
.multiThreaded()
.forEachPixel( i -> i.setReal(((i.getRealDouble() - minPercentileVal) / (diff + eps)) ) );
}
/**
* TODO remove
if (rai.getAt(0) instanceof ByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<ByteType>) rai )
.multiThreaded()
Expand Down Expand Up @@ -349,5 +362,6 @@ void scaleRange(RandomAccessibleInterval<R> rai, double maxPercentileVal, double
} else {
throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai));
}
*/
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,8 @@
import net.imglib2.img.basictypeaccess.array.FloatArray;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.type.numeric.integer.ShortType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.integer.UnsignedIntType;
import net.imglib2.type.numeric.integer.UnsignedShortType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import net.imglib2.view.IntervalView;
Expand Down Expand Up @@ -384,9 +377,20 @@ public static void test3() {
System.out.print(true);
}

@SuppressWarnings("unchecked")
public < R extends RealType< R > & NativeType< R > >
void zeroMeanUnitVariance(RandomAccessibleInterval<R> rai, double mean, double std) {
R type = Util.getTypeFromInterval(rai);
if (type instanceof IntegerType) {
LoopBuilder.setImages( rai )
.multiThreaded()
.forEachPixel( i -> i.setReal(Math.floor((i.getRealDouble() - mean) / (std + eps)) ) );
} else {
LoopBuilder.setImages( rai )
.multiThreaded()
.forEachPixel( i -> i.setReal(((i.getRealDouble() - mean) / (std + eps)) ) );
}
/**
* TODO remove
if (rai.getAt(0) instanceof ByteType) {
LoopBuilder.setImages( (RandomAccessibleInterval<ByteType>) rai )
.multiThreaded()
Expand Down Expand Up @@ -426,5 +430,6 @@ void zeroMeanUnitVariance(RandomAccessibleInterval<R> rai, double mean, double s
} else {
throw new IllegalArgumentException("Unsupported data type: " + Util.getTypeFromInterval(rai));
}
*/
}
}

0 comments on commit cba377b

Please sign in to comment.