Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,109 +20,120 @@
public class ParameterUtils {

/**
* This method is used by {@link StateNode#fromXML(Node)} for a {@link Scalar},
* and has to be consistent with the method {@link StateNode#toString()} for creating XML.
* This method is used by {@link StateNode#fromXML(Node)} to restore a parameter from its
* serialized state-file string, and must stay consistent with {@link #paramToString(StateNode)}.
* <p>
* For example, <code>kappa{[0.0,Infinity]}: 29</code>,
* or <code>freqs{4, [0.0,1.0]}: 0.25 0.25 0.25 0.25</code>.
* If no bounds, then <code>isEstimated: true</code>, or <code>isSelected{2}: true false</code>.
* In BEAST3, bounds are derived from the parameter's domain and are never written to the
* state file. The expected format is therefore always bound-free:
* <ul>
* <li>scalar: {@code kappa: 29}</li>
* <li>vector: {@code freqs{4}: 0.25 0.25 0.25 0.25}</li>
* <li>boolean scalar: {@code isEstimated: true}</li>
* <li>boolean vector: {@code isSelected{2}: true false}</li>
* </ul>
* A state file entry that still contains explicit bounds (BEAST2 legacy format such as
* {@code kappa{[0.0,Infinity]}: 29}) is rejected with {@link IllegalArgumentException}.
*
* @param node XML node
* @param param a parameter which is also a {@link StateNode}
* @param node XML node whose text content is the serialized parameter string
* @param param the target {@link StateNode} to restore
* @throws IllegalArgumentException if the string contains legacy explicit bounds
* @throws RuntimeException if the string format is unrecognised
*/
public static void parseParameter(final Node node, StateNode param) {

final NamedNodeMap atts = node.getAttributes();
// set ID from XML
param.setID(atts.getNamedItem("id").getNodeValue());
final String str = node.getTextContent();

// need to sync with toString
Pattern boundedPattern = Pattern.compile("^.*" + // id
// shape is optional, empty for scalar, or size of vector, or [3, 4] for matrix
// Explicit bounds in state files are a BEAST2 legacy format.
// In BEAST3, bounds are derived from the domain (see BoundedParam removal).
// Fail fast so the user knows to restart rather than resume from such a file.
Pattern boundedPattern = Pattern.compile("^.*" +
"\\{" + "(?:(\\d+|\\[\\d+,\\s*\\d+\\]),\\s*)?" +
"[\\[\\(](.*),(.*)[\\]\\)]" + "\\}" + // bounds
":\\s*(.*)\\s*$"); // value(s)
Matcher matcher1 = boundedPattern.matcher(str);
"[\\[\\(](.*),(.*)[\\]\\)]" + "\\}" +
":\\s*(.*)\\s*$");
if (boundedPattern.matcher(str).matches()) {
throw new IllegalArgumentException(
"XML file entry '" + str + "' contains explicit bounds, which are not " +
"supported in BEAST3. Bounds are now derived from the parameter domain; " +
"values can be constrained further using a prior distribution.");
}

Pattern noboundPattern = Pattern.compile("^.*" + // id
// shape is optional, empty for scalar, or size of vector, or [3, 4] for matrix
// All BEAST3 parameter types serialize without explicit bounds.
// The non-greedy prefix (.*?) ensures the optional {shape} group is captured
// for vector types (e.g. "freqs{4}: 0.25 ..."), and the non-greedy suffix (.*?)
// lets the trailing \s* absorb any whitespace the vector loop appends.
// Format: id{shape}: value(s) — {shape} is absent for scalars.
Pattern noboundPattern = Pattern.compile("^.*?" +
"(?:\\{(\\d+|\\[\\d+,\\s*\\d+\\])\\})?" +
":\\s*(.*)\\s*$"); // value(s)
Matcher matcher2 = noboundPattern.matcher(str);

// scalar and vector use shape for validation, but it is compulsory for matrix.
":\\s*(.*?)\\s*$");
Matcher matcher = noboundPattern.matcher(str);

if (matcher1.matches()) { // with bounds
// id is already assigned
final String shape = matcher1.group(1);
final String lower = matcher1.group(2);
final String upper = matcher1.group(3);
final String valuesAsString = matcher1.group(4);
if (matcher.matches()) {
final String shape = matcher.group(1); // null for scalars
final String valuesAsString = matcher.group(2);
final String[] valuesStr = valuesAsString.split("\\s+");

final String[] valuesStr = valuesAsString.split(" ");
if (param instanceof RealScalarParam<?> realScalarParam) {
realScalarParam.fromXML(shape, valuesStr);
} else if (param instanceof IntScalarParam<?> intScalarParam) {
intScalarParam.fromXML(shape, valuesStr);
// } else if (param instanceof BoundedParam<?> boundedParam) { //TODO
// boundedParam.fromXML(lower, upper, shape, valuesStr);
} else
throw new RuntimeException("Unknown parameter type : " + param.getClass().getName());

} else if (matcher2.matches()) { // without bounds

final String shape = matcher2.group(1); // null for scalar
final String valuesAsString = matcher2.group(2);
final String[] valuesStr = valuesAsString.split(" ");

if (param instanceof BoolScalarParam boolScalar) {
} else if (param instanceof RealVectorParam<?> realVectorParam) {
realVectorParam.fromXML(shape, valuesStr);
} else if (param instanceof IntVectorParam<?> intVectorParam) {
intVectorParam.fromXML(shape, valuesStr);
} else if (param instanceof BoolScalarParam boolScalar) {
boolScalar.fromXML(valuesStr[0]);
} else if (param instanceof BoolVectorParam boolVector) {
boolVector.fromXML(valuesStr);
//TODO } else if (param instanceof BoolMatrixParam boolMatrix) {
} else
throw new RuntimeException("Unknown parameter type : " + param.getClass().getName());

} else {
throw new RuntimeException("String could not be parsed to parameter : " + str);
}
}

/**
* This method is used by {@link StateNode#toString()}.
* @see #parseParameter(Node, StateNode)
* @param param a parameter
* @return kappa{[0.0,Infinity]}: 29 or
* freqs{4, [0.0,1.0]}: 0.25 0.25 0.25 0.25
* Serializes a parameter to a bound-free string for state-file persistence.
* Must stay consistent with {@link #parseParameter(Node, StateNode)}.
* <p>
* Format:
* <ul>
* <li>scalar: {@code kappa: 29}</li>
* <li>vector: {@code freqs{4}: 0.25 0.25 0.25 0.25}</li>
* <li>boolean scalar: {@code isEstimated: true}</li>
* <li>boolean vector: {@code isSelected{2}: true false}</li>
* </ul>
* Bounds are not written; they are derived from the domain at runtime.
*
* @param param a parameter
* @return the serialized string
*/
public static String paramToString(StateNode param) {
String str = param.getID();
if (param instanceof Tensor<?,?> tensor) {
str += "{";
String shapeStr = TypeUtils.shapeToString(tensor);
// empty for scalar, or size of vector, or [3, 4] for matrix.
// scalar and vector use shape for validation, but it is compulsory for matrix.
// Empty for scalars; size for vectors; [r,c] for matrices.
// Scalars drop the braces entirely; vectors and matrices keep them
// so parseParameter can validate the element count on restore.
if (!shapeStr.isEmpty())
str += shapeStr + ", ";
// if (param instanceof BoundedParam<?> boundedParam)
// str += boundedParam.boundsToString();
// check if nothing inside { }
str += shapeStr;
if (str.endsWith("{"))
str = str.substring(0, str.length() - 1);
str = str.substring(0, str.length() - 1); // scalar: remove empty braces
else
str += "}"; // close {
str += "}";
}
str += ": ";
if (param instanceof Scalar scalar)
if (param instanceof Scalar scalar)
str += scalar.get();
else if (param instanceof Vector vector) {
else if (param instanceof Vector vector) {
List elements = vector.getElements();
for (Object element : elements)
str += element.toString() + " ";
}

return str; //+ " ";
return str;
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package beast.base.spec.inference.parameter;

import beast.base.spec.domain.Int;
import beast.base.spec.domain.PositiveReal;
import beast.base.spec.domain.Real;
import org.junit.jupiter.api.Test;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

import javax.xml.parsers.DocumentBuilderFactory;

import static org.junit.jupiter.api.Assertions.*;

/**
* Tests for {@link ParameterUtils#parseParameter} and {@link ParameterUtils#paramToString},
* covering the full round-trip for all supported parameter types and verifying that
* legacy BEAST2 state files with explicit bounds are rejected.
*/
public class ParameterUtilsTest {

// ------------------------------------------------------------------ helpers

private static org.w3c.dom.Node createNode(String id, String textContent) throws Exception {
Document doc = DocumentBuilderFactory.newInstance().newDocumentBuilder().newDocument();
Element el = doc.createElement("stateNode");
el.setAttribute("id", id);
el.setTextContent(textContent);
return el;
}

// ------------------------------------------------------------------ scalar round-trips

@Test
void testRealScalarRoundTrip() throws Exception {
RealScalarParam<PositiveReal> src = new RealScalarParam<>(1.5, PositiveReal.INSTANCE);
src.setID("kappa");

String str = src.toString();
assertEquals("kappa: 1.5", str);

RealScalarParam<PositiveReal> target = new RealScalarParam<>(0.1, PositiveReal.INSTANCE);
ParameterUtils.parseParameter(createNode("kappa", str), target);

assertEquals(1.5, target.get(), 1e-12);
assertEquals("kappa", target.getID());
}

@Test
void testIntScalarRoundTrip() throws Exception {
IntScalarParam<Int> src = new IntScalarParam<>(7, Int.INSTANCE);
src.setID("popSize");

String str = src.toString();
assertEquals("popSize: 7", str);

IntScalarParam<Int> target = new IntScalarParam<>(1, Int.INSTANCE);
ParameterUtils.parseParameter(createNode("popSize", str), target);

assertEquals(7, target.get());
assertEquals("popSize", target.getID());
}

@Test
void testBoolScalarRoundTrip() throws Exception {
BoolScalarParam src = new BoolScalarParam(true);
src.setID("isEstimated");

String str = src.toString();
assertEquals("isEstimated: true", str);

BoolScalarParam target = new BoolScalarParam(false);
ParameterUtils.parseParameter(createNode("isEstimated", str), target);

assertTrue(target.get());
assertEquals("isEstimated", target.getID());
}

// ------------------------------------------------------------------ vector round-trips

@Test
void testRealVectorRoundTrip() throws Exception {
RealVectorParam<Real> src = new RealVectorParam<>(new double[]{0.1, 0.2, 0.3, 0.4}, Real.INSTANCE);
src.setID("freqs");

String str = src.toString();
// shape must appear so parseParameter can validate element count
assertTrue(str.startsWith("freqs{4}: "), "Expected 'freqs{4}: ...' but was: " + str);

RealVectorParam<Real> target = new RealVectorParam<>(new double[]{0.25, 0.25, 0.25, 0.25}, Real.INSTANCE);
ParameterUtils.parseParameter(createNode("freqs", str), target);

assertEquals(0.1, target.get(0), 1e-12);
assertEquals(0.2, target.get(1), 1e-12);
assertEquals(0.3, target.get(2), 1e-12);
assertEquals(0.4, target.get(3), 1e-12);
assertEquals("freqs", target.getID());
}

@Test
void testIntVectorRoundTrip() throws Exception {
IntVectorParam<Int> src = new IntVectorParam<>(new int[]{1, 2, 3}, Int.INSTANCE);
src.setID("counts");

String str = src.toString();
assertTrue(str.startsWith("counts{3}: "), "Expected 'counts{3}: ...' but was: " + str);

IntVectorParam<Int> target = new IntVectorParam<>(new int[]{0, 0, 0}, Int.INSTANCE);
ParameterUtils.parseParameter(createNode("counts", str), target);

assertEquals(1, target.get(0));
assertEquals(2, target.get(1));
assertEquals(3, target.get(2));
assertEquals("counts", target.getID());
}

@Test
void testBoolVectorRoundTrip() throws Exception {
BoolVectorParam src = new BoolVectorParam(new boolean[]{true, false, true});
src.setID("isSelected");

String str = src.toString();
assertTrue(str.startsWith("isSelected{3}: "), "Expected 'isSelected{3}: ...' but was: " + str);

BoolVectorParam target = new BoolVectorParam(new boolean[]{false, false, false});
ParameterUtils.parseParameter(createNode("isSelected", str), target);

assertTrue(target.get(0));
assertFalse(target.get(1));
assertTrue(target.get(2));
assertEquals("isSelected", target.getID());
}

// ------------------------------------------------------------------ legacy bounds rejection

@Test
void testLegacyScalarBoundsThrows() throws Exception {
// BEAST2 format: explicit bounds in braces
String legacy = "kappa{[0.0,Infinity]}: 1.5";
RealScalarParam<PositiveReal> param = new RealScalarParam<>(1.0, PositiveReal.INSTANCE);
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class,
() -> ParameterUtils.parseParameter(createNode("kappa", legacy), param));
assertTrue(ex.getMessage().contains("explicit bounds"));
assertTrue(ex.getMessage().contains("prior distribution"));
}

@Test
void testLegacyVectorBoundsThrows() throws Exception {
// BEAST2 format: shape + explicit bounds
String legacy = "freqs{4, [0.0,1.0]}: 0.25 0.25 0.25 0.25";
RealVectorParam<Real> param = new RealVectorParam<>(new double[]{0.25, 0.25, 0.25, 0.25}, Real.INSTANCE);
assertThrows(IllegalArgumentException.class,
() -> ParameterUtils.parseParameter(createNode("freqs", legacy), param));
}

// ------------------------------------------------------------------ paramToString format

@Test
void testParamToStringScalarHasNoBraces() {
// Scalars must not emit {}: the shape is implicit (rank 0)
RealScalarParam<Real> param = new RealScalarParam<>(2.5, Real.INSTANCE);
param.setID("mu");
assertFalse(param.toString().contains("{"), "Scalar toString must not contain braces");
}

@Test
void testParamToStringVectorHasShapeNoBoundsComma() {
// Vectors must emit {N} — no trailing comma from the old BoundedParam format
RealVectorParam<Real> param = new RealVectorParam<>(new double[]{1.0, 2.0}, Real.INSTANCE);
param.setID("rates");
String s = param.toString();
assertTrue(s.contains("{2}"), "Vector toString must contain '{2}', got: " + s);
assertFalse(s.contains("{2,"), "Vector toString must not contain legacy '{2,' format, got: " + s);
}
}
Loading