diff --git a/osgar/bus.py b/osgar/bus.py index cea0cfe71..d5b8787d5 100644 --- a/osgar/bus.py +++ b/osgar/bus.py @@ -75,10 +75,26 @@ def register(self, *outputs): continue idx = self.logger.register(f'{self.name}.{o}') self.stream_id[o] = idx - if name_and_type.endswith(':null'): + + # Resolve configuration overrides if provided + config_name_and_type = name_and_type + if hasattr(self, 'config_out') and self.config_out: + for item in self.config_out: + if item == o or item.startswith(o + ':'): + # ONLY override if the item in config explicitly specifies a modifier/suffix (contains a colon) + if ':' in item: + config_name_and_type = item + break + + if config_name_and_type.endswith(':null'): self.no_output.add(idx) - if name_and_type.endswith(':gz'): + self.compressed_output.discard(idx) + elif config_name_and_type.endswith(':gz'): self.compressed_output.add(idx) + self.no_output.discard(idx) + else: + self.no_output.discard(idx) + self.compressed_output.discard(idx) self.out[o] = [] self.slots[o] = [] diff --git a/osgar/record.py b/osgar/record.py index e71b71bba..b7b5c53e8 100644 --- a/osgar/record.py +++ b/osgar/record.py @@ -29,7 +29,10 @@ def __init__(self, config, logger): if env is not None: assert 'env' not in module_config_init, module_config_init module_config_init['env'] = env.copy() - self.modules[module_name] = klass(module_config_init, bus=self.bus.handle(module_name)) + bus_handle = self.bus.handle(module_name) + if 'out' in module_config: + bus_handle.config_out = module_config['out'] + self.modules[module_name] = klass(module_config_init, bus=bus_handle) for sender, receiver in config['links']: self.bus.connect(sender, receiver, self.modules) diff --git a/osgar/test_bus.py b/osgar/test_bus.py index 84935491b..a5bb01cd5 100644 --- a/osgar/test_bus.py +++ b/osgar/test_bus.py @@ -205,4 +205,35 @@ def sleep(): t.join(0.01) self.assertFalse(t.is_alive()) + def test_config_overrides(self): + logger = MagicMock() + logger.register = MagicMock(side_effect=[10, 11, 12, 13]) + bus = Bus(logger) + handle = bus.handle('test') + handle.config_out = ["raw:gz", "status:null", "preserved", "depth:"] + + handle.register("raw", "status", "preserved:gz", "depth:gz") + + self.assertIn("raw", handle.stream_id) + self.assertIn("status", handle.stream_id) + self.assertIn("preserved", handle.stream_id) + self.assertIn("depth", handle.stream_id) + + raw_idx = handle.stream_id["raw"] + status_idx = handle.stream_id["status"] + preserved_idx = handle.stream_id["preserved"] + depth_idx = handle.stream_id["depth"] + + self.assertIn(raw_idx, handle.compressed_output) + self.assertNotIn(raw_idx, handle.no_output) + + self.assertIn(status_idx, handle.no_output) + self.assertNotIn(status_idx, handle.compressed_output) + + self.assertIn(preserved_idx, handle.compressed_output) + self.assertNotIn(preserved_idx, handle.no_output) + + self.assertNotIn(depth_idx, handle.compressed_output) + self.assertNotIn(depth_idx, handle.no_output) + # vim: expandtab sw=4 ts=4 diff --git a/osgar/test_zmqrouter.py b/osgar/test_zmqrouter.py index ee8567419..24f6cd2a2 100644 --- a/osgar/test_zmqrouter.py +++ b/osgar/test_zmqrouter.py @@ -215,6 +215,35 @@ def test_compress(self): } record(config, log_filename='compressed-publisher.log') + def test_config_overrides(self): + config = { 'version': 2, 'robot': { 'modules': {}, 'links': [] } } + config['robot']['modules']['publisher'] = { + "driver": "osgar.test_zmqrouter:Publisher", + "init": { "output": "count" }, + "out": ["count:gz"] + } + record(config, log_filename='overridden-publisher.log') + + import msgpack + from osgar.lib.serialize import deserialize + + with osgar.logger.LogReader(self.tempdir/"overridden-publisher.log", only_stream_id=1) as log: + last_dt = datetime.timedelta() + count = 0 + for dt, channel, data in log: + self.assertGreater(dt, last_dt) + + # Verify that the raw data in the log is indeed compressed (it should be an ExtType with code 42) + unpacked = msgpack.unpackb(data, raw=False) + self.assertIsInstance(unpacked, msgpack.ExtType) + self.assertEqual(unpacked.code, 42) + + # Verify that we can successfully deserialize it back to the original count + self.assertEqual(deserialize(data), count) + + last_dt = dt + count += 1 + def test_null(self): config = { 'version': 2, 'robot': { 'modules': {}, 'links': [] } } config['robot']['modules']['publisher'] = { diff --git a/osgar/zmqrouter.py b/osgar/zmqrouter.py index 1312f8680..8111b534e 100644 --- a/osgar/zmqrouter.py +++ b/osgar/zmqrouter.py @@ -49,7 +49,7 @@ def record(config, log_prefix=None, log_filename=None, duration_sec=None): with osgar.logger.LogWriter(prefix=log_prefix, filename=log_filename, note=str(sys.argv)) as log: log.write(0, bytes(str(config), 'ascii')) g_logger.info(log.filename) - with _Router(log) as router: + with _Router(log, config=config['robot']['modules']) as router: modules = {} s = sys.modules[__name__].__spec__ for module_name, module_config in config['robot']['modules'].items(): @@ -83,8 +83,9 @@ def record(config, log_prefix=None, log_filename=None, duration_sec=None): class _Router: - def __init__(self, logger): + def __init__(self, logger, config=None): self.logger = logger + self.config = config self.start_time = self.logger.start_time self._context = zmq.Context() self.socket = self._context.socket(zmq.ROUTER) @@ -125,16 +126,41 @@ def register_nodes(self, nodes, timeout=datetime.timedelta.max): assert sender in nodes, (sender, nodes) # it is one of the nodes we expect self.nodes[sender] = collections.deque() # receiving queue self.delays[sender] = datetime.timedelta() + + node_name = str(sender, 'ascii') + module_config = self.config.get(node_name, {}) if self.config else {} + config_out = module_config.get('out', []) + for name_and_type in args: o, *suffix = name_and_type.split(b':') suffix = suffix[0] if suffix else b'' + + # Resolve overrides from configuration + o_str = str(o, 'ascii') + config_suffix = None + for item in config_out: + if item == o_str: + # Plain name (no colon modifier) - preserve driver default suffix + break + elif item.startswith(o_str + ':'): + # Overridden with a specific suffix or a trailing empty colon (e.g. 'depth:') + config_suffix = bytes(item.split(':', 1)[1], 'ascii') + break + if config_suffix is not None: + suffix = config_suffix + link_from = sender + b"." + o idx = self.logger.register(str(link_from, 'ascii')) self.stream_id[link_from] = idx if suffix == b'null': self.no_output.add(idx) + self.compressed_output.discard(idx) elif suffix == b'gz': self.compressed_output.add(idx) + self.no_output.discard(idx) + else: + self.no_output.discard(idx) + self.compressed_output.discard(idx) if self.nodes.keys() == nodes: return raise RuntimeError("unexpected stop")