""" Helper functions that make constructing hardware easier.
"""
from __future__ import annotations
import collections
import math
import numbers
import sys
from functools import reduce
from .core import working_block, _NameIndexer, _get_debug_mode, Block
from .pyrtlexceptions import PyrtlError, PyrtlInternalError
from .wire import WireVector, Input, Output, Const, Register, WrappedWireVector
from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list
# -----------------------------------------------------------------
#        ___       __   ___  __   __
#  |__| |__  |    |__) |__  |__) /__`
#  |  | |___ |___ |    |___ |  \ .__/
#
# Name indexer used by probe() to build the autogenerated part of a default
# probe name (via probeIndexer.make_valid_string()).
probeIndexer = _NameIndexer('Probe-')
[docs]
def probe(w, name=None):
    """ Attach a named probe to a WireVector and hand the WireVector back unchanged.

    :param WireVector w: WireVector to observe
    :param str name: optional name for the probe; when omitted, a name is
        built from the probe indexer and the name of `w`
    :return: the original WireVector `w`
    :rtype: WireVector
    :raises PyrtlError: if `w` is not a WireVector

    Because `w` is returned untouched, a probe can be dropped into the middle
    of an existing expression. For example ``y <<= x[0:3] + 4`` can become
    ``y <<= probe(x)[0:3] + 4``; likewise ``y <<= probe(x[0:3]) + 4``,
    ``y <<= probe(x[0:3] + 4)``, and ``probe(y) <<= x[0:3] + 4`` are all
    valid uses.

    When debug mode is enabled, the probe name and the creation stack of `w`
    (as reported by `get_stack`) are printed.

    Note: every call creates a new, named Output wire driven by `w`. That extra
    Output can confuse post-processing transforms such as output to verilog.
    """
    if not isinstance(w, WireVector):
        raise PyrtlError('Only WireVectors can be probed')

    probe_name = name
    if probe_name is None:
        probe_name = '(%s: %s)' % (probeIndexer.make_valid_string(), w.name)

    if _get_debug_mode():
        print("Probe: " + probe_name + ' ' + get_stack(w))

    # No bitwidth is given here: the Output picks up its length from `w`
    # when it is assigned.
    probe_output = Output(name=probe_name)
    probe_output <<= w
    return w
# Name indexer used by rtl_assert() to name the Output wire it creates for
# each assertion (via assertIndexer.make_valid_string()).
assertIndexer = _NameIndexer('assertion')
[docs]
def rtl_assert(w, exp, block=None):
    """ Add hardware assertions to be checked on the RTL design.

    :param WireVector w: 1-bit WireVector that must stay high
    :param Exception exp: Exception instance to raise when the assertion fails
    :param Block block: block to which the assertion should be added (defaults
        to the working block)
    :return: the Output wire for the assertion (can be ignored in most cases)
    :rtype: Output
    :raises PyrtlError: if `w` is not a 1-bit WireVector belonging to `block`,
        or if `exp` is not an Exception instance, or if `exp` is a KeyError
    :raises PyrtlInternalError: if `w` is already a key of the block's
        assertion dictionary

    If at any time during execution the wire `w` is not true (i.e. it is
    low), then checking the assertions will raise `exp`.
    """
    block = working_block(block)

    if not isinstance(w, WireVector):
        raise PyrtlError('Only WireVectors can be asserted with rtl_assert')
    if len(w) != 1:
        raise PyrtlError('rtl_assert checks only a WireVector of bitwidth 1')
    if not isinstance(exp, Exception):
        raise PyrtlError('the second argument to rtl_assert must be an instance of Exception')
    # KeyError is rejected explicitly; check_rtl_assertions treats KeyError
    # as "wire not available" when inspecting assertion wires.
    if isinstance(exp, KeyError):
        raise PyrtlError('the second argument to rtl_assert cannot be a KeyError')
    # A single membership check is enough; the original repeated this exact
    # test a second time with a different (unreachable) message.
    if w not in block.wirevector_set:
        raise PyrtlError('assertion wire not part of the block to which it is being added')

    if w in block.rtl_assert_dict:
        raise PyrtlInternalError('assertion conflicts with existing registered assertion')

    # The assertion is tracked through a dedicated 1-bit Output driven by `w`.
    assert_wire = Output(bitwidth=1, name=assertIndexer.make_valid_string(), block=block)
    assert_wire <<= w
    block.rtl_assert_dict[assert_wire] = exp
    return assert_wire
[docs]
def check_rtl_assertions(sim):
    """ Check the values in `sim` to see if any registered assertions fail.

    :param Simulation sim: Simulation in which to check the assertions
    :return: None
    :raises Exception: the exception registered for the first assertion wire
        whose inspected value is falsy

    Every ``(wire, exception)`` pair in ``sim.block.rtl_assert_dict`` is
    examined. Wires for which ``sim.inspect`` raises ``KeyError`` are skipped.
    """
    for wire, exp in sim.block.rtl_assert_dict.items():
        # Only the lookup is guarded: keeping `raise exp` outside the try
        # guarantees that a failing assertion is never swallowed, even if
        # the registered exception happens to be a KeyError.
        try:
            value = sim.inspect(wire)
        except KeyError:
            # The simulation has no value for this wire; nothing to check.
            continue
        if not value:
            raise exp
[docs]
def log2(integer_val):
    """ Return the base-2 logarithm of an integer that is an exact power of two.

    :param int integer_val: The integer to take the log base 2 of.
    :return: The log base 2 of `integer_val`
    :rtype: int
    :raises PyrtlError: if `integer_val` is not an int, is not positive, or is
        not an exact power of two

    Handy for validating that a size argument is a power of two while
    computing the matching number of address bits.

    Examples: ::

        log2(2)  # returns 1
        log2(256)  # returns 8
        addrwidth = log2(size_of_memory)  # will fail if size_of_memory is not a power of two
    """
    if not isinstance(integer_val, int):
        raise PyrtlError('this function can only take integers')
    if integer_val <= 0:
        raise PyrtlError('this function can only take positive numbers 1 or greater')

    # A power of two has exactly one bit set, so clearing the lowest set bit
    # (n & (n - 1)) must leave zero.
    has_single_bit = (integer_val & (integer_val - 1)) == 0
    if not has_single_bit:
        raise PyrtlError('this function can only take even powers of 2')
    return integer_val.bit_length() - 1
[docs]
def truncate(wirevector_or_integer, bitwidth):
    """ Return a WireVector or integer cut down to the given bitwidth.

    :param wirevector_or_integer: Either a WireVector or an integer to be truncated.
    :param int bitwidth: The length to which the first argument should be truncated.
    :return: A truncated WireVector or integer as appropriate.
    :raises PyrtlError: if `bitwidth` is less than 1

    The most significant bits are dropped so that only `bitwidth` bits remain.
    Objects that provide a ``truncate`` method (such as WireVectors) are
    truncated by calling it; anything else is masked with a bitmask of
    `bitwidth` ones.

    Examples: ::

        truncate(9,3)  # returns 1  (0b1001 truncates to 0b001)
        truncate(5,3)  # returns 5  (0b101 truncates to 0b101)
        truncate(-1,3)  # returns 7  (-0b1 truncates to 0b111)
        y = truncate(x+1, x.bitwidth)  # y.bitwidth will equal x.bitwidth
    """
    if bitwidth < 1:
        raise PyrtlError('bitwidth must be a positive integer')

    target = wirevector_or_integer
    try:
        result = target.truncate(bitwidth)
    except AttributeError:
        # No truncate() available: treat the argument as a plain integer.
        low_bits_mask = (1 << bitwidth) - 1
        result = target & low_bits_mask
    return result
class MatchedFields(collections.namedtuple('MatchedFields', ['matched', 'fields'])):
    """ Result pair ``(matched, fields)`` that also works in a ``with`` statement.

    Entering the context pushes `matched` as a condition (via the
    conditional module's ``_push_condition``) and yields `fields`; leaving the
    context pops that condition again.
    """

    def __enter__(self):
        # Imported lazily, inside the method, as in the rest of this class.
        from .conditional import _push_condition
        _push_condition(self.matched)
        return self.fields

    def __exit__(self, *exc_info):
        from .conditional import _pop_condition
        _pop_condition()
[docs]
def match_bitpattern(w, bitpattern, field_map=None):
    """Return a 1-bit WireVector that is 1 if and only if `w` matches
    `bitpattern`, together with a named tuple of the matched fields, if any.
    The result is compatible with the `with` statement.

    :param WireVector w: The WireVector to be compared to the bitpattern
    :param str bitpattern: A string holding the pattern (of bits and wildcards)
        to match
    :param field_map: (optional) A map from single-character field name in the
        bitpattern to the desired name of the field in the returned namedtuple.
        If given, all non-"1"/"0"/"?" characters in the `bitpattern` must be
        present in the map.
    :return: A tuple of a 1-bit WireVector carrying the result of the
        comparison, followed by a named tuple containing the matched fields,
        if any.
    :raises PyrtlError: if `bitpattern` is not a string, if its length (after
        removing whitespace and underscores) differs from the length of `w`,
        or if `field_map` is given but lacks one of the field characters

    The WireVector is compared against a pattern of bits in which some
    positions may be wildcards. If any ``1`` or ``0`` in the pattern fails to
    match, the result wire carries ``0``; otherwise it carries ``1``. Every
    character other than ``1`` and ``0`` is a wildcard. Whitespace and
    underscores are ignored and may be used for readability.

    Wildcard characters other than ``?`` additionally name a field. All bits
    sharing the same character, consecutive or not, are concatenated
    left-to-right into one field of the returned tuple. Fields appear in the
    tuple in the order in which their character first occurs in the pattern,
    read left to right. For example, ``01aa1?bbb11a`` matches a value such as
    ``010010100111`` and yields::

        (a, b) = (0b001, 0b100)

    where ``a`` is the concatenation of bits 9, 8, and 0, and ``b`` is the
    concatenation of bits 5, 4, and 3.

    A prime use is instruction decoding. Here some RISC-V is decoded: ::

        with pyrtl.conditional_assignment:
            with match_bitpattern(inst, "iiiiiiiiiiiirrrrr010ddddd0000011") as (imm, rs1, rd):
                regfile[rd] |= mem[(regfile[rs1] + imm.sign_extended(32)).truncate(32)]
                pc.next |= pc + 1
            with match_bitpattern(inst, "iiiiiiirrrrrsssss010iiiii0100011") as (imm, rs2, rs1):
                mem[(regfile[rs1] + imm.sign_extended(32)).truncate(32)] |= regfile[rs2]
                pc.next |= pc + 1
            with match_bitpattern(inst, "0000000rrrrrsssss111ddddd0110011") as (rs2, rs1, rd):
                regfile[rd] |= regfile[rs1] & regfile[rs2]
                pc.next |= pc + 1
            with match_bitpattern(inst, "0000000rrrrrsssss000ddddd0110011") as (rs2, rs1, rd):
                regfile[rd] |= (regfile[rs1] + regfile[rs2]).truncate(32)
                pc.next |= pc + 1
            # ...etc...

    Some smaller examples: ::

        m, _ = match_bitpattern(w, '0101')  # basically the same as w == '0b0101'
        m, _ = match_bitpattern(w, '01?1')  # m will be true when w is '0101' or '0111'
        m, _ = match_bitpattern(w, '??01')  # m be true when last two bits of w are '01'
        m, _ = match_bitpattern(w, '??_0 1')  # spaces/underscores are ignored, same as line above
        m, (a, b) = match_bitpattern(w, '01aa1?bbb11a')  # all bits with same letter make up same field
        m, fs = match_bitpattern(w, '01aa1?bbb11a', {'a': 'foo', 'b': 'bar'})  # fields fs.foo, fs.bar
    """
    w = as_wires(w)
    if not isinstance(bitpattern, str):
        raise PyrtlError('bitpattern must be a string')

    # Drop readability characters (underscores and any whitespace).
    pattern = ''.join(bitpattern.replace('_', '').split())
    if len(w) != len(pattern):
        raise PyrtlError('bitpattern string different length than wirevector provided')

    # Reverse the pattern so that string index i lines up with bit i of w.
    pattern_lsb_first = pattern[::-1]

    zero_bits = [w[pos] for pos, ch in enumerate(pattern_lsb_first) if ch == '0']
    one_bits = [w[pos] for pos, ch in enumerate(pattern_lsb_first) if ch == '1']
    every_one_set = rtl_all(*one_bits)
    some_zero_set = rtl_any(*zero_bits)
    match = every_one_set & ~some_zero_set

    def field_name(name):
        # Translate a single-character field into its user-facing name.
        if field_map is None:
            return name
        if name not in field_map:
            raise PyrtlError('field_map argument has been given, '
                             'but %s field is not present' % name)
        return field_map[name]

    # Gather the bits of each field, least significant bit first.
    bits_by_field = collections.defaultdict(list)
    for pos, ch in enumerate(pattern_lsb_first):
        if ch not in '01?':
            bits_by_field[ch].append(w[pos])

    def first_occurrence(item):
        return pattern.index(item[0])

    # Order the fields explicitly by where they first appear in the original
    # (left-to-right) pattern rather than relying on dict ordering.
    ordered_fields = sorted(bits_by_field.items(), key=first_occurrence)

    Fields = collections.namedtuple(
        'Fields', ' '.join(field_name(ch) for ch, _ in ordered_fields))
    field_wires = {field_name(ch): concat_list(bits) for ch, bits in ordered_fields}
    return MatchedFields(match, Fields(**field_wires))
def bitpattern_to_val(bitpattern, *ordered_fields, **named_fields):
    """Return an unsigned integer representation of field format filled with
    the provided values.

    :param bitpattern: A string holding the pattern (of bits and named
        wildcards) to fill in
    :param ordered_fields: Values to be matched to the fields of the bit
        pattern, in the order the field letters first appear (left to right).
        If ordered_fields are provided then no named_fields can be used.
    :param named_fields: Values to be matched to the fields of the bit pattern
        by name. If named_fields are provided then no ordered_fields can be
        used. A special keyword argument, 'field_map', can be provided, which
        maps the 1-letter field names in the bitpattern string to longer,
        human readable field names (see example below).
    :return: An unsigned integer carrying the result of the field substitution.
    :rtype: int
    :raises PyrtlError: if the pattern is empty, contains ``?``, mixes ordered
        and named fields, is given the wrong number of fields, is missing a
        named field, or is given a value with too many bits for its field

    Every character of the pattern other than ``0`` and ``1`` names a field
    and, unlike the related function ``match_bitpattern``, no "?" can be
    used. The bits of each provided value are substituted into the positions
    of the corresponding letter, least significant bit at the rightmost
    occurrence. Negative values are accepted and contribute their two's
    complement bits. This sort of bit substitution is useful when creating
    values for testing when the resulting values will be "chopped" up by the
    hardware later (e.g. instruction decode or other bitfield heavy
    functions).

    Examples::

        bitpattern_to_val('0000000sssssrrrrr000ddddd0110011', s=2, r=1, d=3)  # RISCV ADD instr
        # evaluates to 0b00000000001000001000000110110011

        bitpattern_to_val('iiiiiiisssssrrrrr010iiiii0100011', i=1, s=4, r=3)  # RISCV SW instr
        # evaluates to 0b00000000010000011010000010100011

        bitpattern_to_val(
            'iiiiiiisssssrrrrr010iiiii0100011',
            imm=1, rs2=4, rs1=3,
            field_map={'i': 'imm', 's': 'rs2', 'r': 'rs1'}
        )  # RISCV SW instr
        # evaluates to 0b00000000010000011010000010100011
    """
    if not bitpattern:
        raise PyrtlError('bitpattern must be nonempty')
    if len(ordered_fields) > 0 and len(named_fields) > 0:
        raise PyrtlError('named and ordered fields cannot be mixed')
    # Reject anonymous wildcards up front so the caller gets this message
    # rather than a confusing field-count mismatch.
    if '?' in bitpattern:
        raise PyrtlError('all fields in the bitpattern must have names')

    field_map = named_fields.pop('field_map', None)

    # Unique field letters, in order of first appearance (left to right).
    field_letters = []
    for c in bitpattern:
        if c not in ('0', '1') and c not in field_letters:
            field_letters.append(c)

    if ordered_fields:
        if len(field_letters) != len(ordered_fields):
            raise PyrtlError('number of fields and number of unique patterns do not match')
        values = [int(f) for f in ordered_fields]
    else:
        if len(field_letters) != len(named_fields):
            raise PyrtlError('number of fields and number of unique patterns do not match')
        try:
            values = [
                int(named_fields[field_map[letter] if field_map else letter])
                for letter in field_letters
            ]
        except KeyError as e:
            raise PyrtlError('bitpattern field %s was not provided in named_field list'
                             % e.args[0]) from e

    # Walk the pattern from the least significant position, consuming one bit
    # of the matching field value at every field letter.
    remaining = dict(zip(field_letters, values))
    result = 0
    for position, c in enumerate(reversed(bitpattern)):
        if c == '0' or c == '1':
            bit = int(c)
        else:
            bit = remaining[c] & 0x1
            remaining[c] >>= 1
        result |= bit << position

    # After all positions are consumed a value that fit leaves only its sign
    # extension behind: 0 for non-negative values, -1 for negative ones.
    for letter, leftover in remaining.items():
        if leftover not in (0, -1):
            raise PyrtlError('too many bits given to value to fit in field %s' % letter)
    return result
[docs]
def chop(w, *segment_widths):
    """ Split `w` into a list of WireVectors, one slice per requested width.

    :param WireVector w: The WireVector to be chopped up into segments
    :param int segment_widths: Additional arguments are integers which are bitwidths
    :return: A list of WireVectors each with a proper segment width
    :rtype: List[WireVector]
    :raises PyrtlError: if a width is not an integer, or if the widths do not
        add up to the length of `w`

    This is most useful when several "fields" are packed into a single
    WireVector, for example when breaking apart an instruction. A 32-bit MIPS
    I-type (Immediate) instruction has a 6-bit opcode, two 5-bit operands, and
    a 16-bit offset. Each could be sliced by hand (``offset=instr[0:16]``,
    ``rt=instr[16:21]``, and so on), but then the arithmetic is on you; chop
    extracts all fields at once, as shown in the examples below.

    The first width describes the most significant segment and the last width
    the least significant one (just like :func:`.concat`), which is normal
    for hardware functions but makes the list order a little counter
    intuitive.

    Examples: ::

        opcode, rs, rt, offset = chop(instr, 6, 5, 5, 16)  # MIPS I-type instruction
        opcode, instr_index = chop(instr, 6, 26)  # MIPS J-type instruction
        opcode, rs, rt, rd, sa, function = chop(instr, 6, 5, 5, 5, 5, 6)  # MIPS R-type
        msb, middle, lsb = chop(data, 1, 30, 1)  # break out the most and least significant bits
    """
    w = as_wires(w)
    for width in segment_widths:
        if not isinstance(width, int):
            raise PyrtlError('segment widths must be integers')
    total_width = sum(segment_widths)
    if total_width != len(w):
        raise PyrtlError('sum of segment widths must equal length of wirevetor')

    # Walk from the most significant segment downwards, tracking the
    # exclusive upper bit index of the segment currently being described.
    bounds = []
    upper = total_width
    for width in segment_widths:
        lower = upper - width
        bounds.append((lower, upper))
        upper = lower
    return [w[lower:upper] for lower, upper in bounds]
[docs]
def output_list(names, bitwidth=None):
    """ Allocate and return a list of :class:`Outputs<pyrtl.wire.Output>`.

    :param names: Names for the Outputs. Can be a list or single comma/space-separated string
    :param int bitwidth: The desired bitwidth for the resulting Outputs.
    :return: List of Outputs.
    :rtype: List[Output]

    Thin convenience wrapper, equivalent to: ::

        wirevector_list(names, bitwidth, wvtype=pyrtl.wire.Output)
    """
    outputs = wirevector_list(names, bitwidth=bitwidth, wvtype=Output)
    return outputs
[docs]
def register_list(names, bitwidth=None):
    """ Allocate and return a list of :class:`Registers<pyrtl.wire.Register>`.

    :param names: Names for the Registers. Can be a list or single comma/space-separated string
    :param int bitwidth: The desired bitwidth for the resulting Registers.
    :return: List of Registers.
    :rtype: List[Register]

    Thin convenience wrapper, equivalent to: ::

        wirevector_list(names, bitwidth, wvtype=pyrtl.wire.Register)
    """
    registers = wirevector_list(names, bitwidth=bitwidth, wvtype=Register)
    return registers
[docs]
def wirevector_list(names, bitwidth=None, wvtype=WireVector):
    """ Allocate and return a list of WireVectors.

    :param names: Names for the WireVectors. Can be a list or single comma/space-separated string
    :param bitwidth: The desired bitwidth for the resulting WireVectors: a
        single integer applied to every wire, or a sequence with one bitwidth
        per name. Defaults to 1 when omitted.
    :param WireVector wvtype: Which WireVector type to create.
    :return: List of WireVectors.
    :rtype: List[WireVector]
    :raises PyrtlError: if any name contains ``/`` while `bitwidth` is given
    :raises ValueError: if the number of bitwidths differs from the number of
        names

    A name may carry its own bitwidth, written after a ``/`` (for example
    ``'four_bits/4'``). This form cannot be combined with an explicit
    `bitwidth` argument.

    Examples: ::

        wirevector_list(['name1', 'name2', 'name3'])
        wirevector_list('name1, name2, name3')
        wirevector_list('input1 input2 input3', bitwidth=8, wvtype=pyrtl.wire.Input)
        wirevector_list('output1, output2 output3', bitwidth=3, wvtype=pyrtl.wire.Output)
        wirevector_list('two_bits/2, four_bits/4, eight_bits/8')
        wirevector_list(['name1', 'name2', 'name3'], bitwidth=[2, 4, 8])
    """
    if isinstance(names, str):
        # Commas and whitespace are both accepted as separators.
        names = names.replace(',', ' ').split()

    uses_inline_width = any('/' in name for name in names)
    if uses_inline_width and bitwidth is not None:
        raise PyrtlError('only one of optional "/" or bitwidth parameter allowed')

    if bitwidth is None:
        bitwidth = 1
    if isinstance(bitwidth, numbers.Integral):
        # One shared width: replicate it for every name.
        bitwidth = [bitwidth] * len(names)
    if len(bitwidth) != len(names):
        raise ValueError('number of names ' + str(len(names))
                         + ' should match number of bitwidths ' + str(len(bitwidth)))

    wires = []
    for fullname, default_width in zip(names, bitwidth):
        pieces = fullname.split('/')
        if len(pieces) == 2:
            wire_name, width = pieces
        else:
            # No (single) inline width: use the name as-is with its default.
            wire_name, width = fullname, default_width
        wires.append(wvtype(bitwidth=int(width), name=wire_name))
    return wires
[docs]
def val_to_signed_integer(value: int, bitwidth: int) -> int:
    """Interpret the low `bitwidth` bits of `value` as a two's complement number.

    :param value: A Python integer holding the value to convert.
    :param bitwidth: The length of the integer in bits to assume for
        conversion.
    :return: ``value`` as a signed integer
    :raises PyrtlError: if either argument is a WireVector, or if `bitwidth`
        is less than 1

    Given an unsigned integer (not a ``WireVector``!) convert that to a signed
    integer. This is useful for printing and interpreting values which are
    negative numbers in two's complement. ::

        val_to_signed_integer(0xff, 8) == -1

    ``val_to_signed_integer`` can also be used as a ``repr_func`` for
    :py:meth:`.SimulationTrace.render_trace`, to display signed integers in
    traces::

        bitwidth = 3
        counter = Register(name='counter', bitwidth=bitwidth)
        counter.next <<= counter + 1
        sim = Simulation()
        sim.step_multiple(nsteps=2 ** bitwidth)

        # The rendered trace shows the counter going 0, 1, 2, 3, -4, -3, -2, -1
        sim.tracer.render_trace(repr_func=val_to_signed_integer)
    """
    if any(isinstance(arg, WireVector) for arg in (value, bitwidth)):
        raise PyrtlError('inputs must not be wirevectors')
    if bitwidth < 1:
        raise PyrtlError('bitwidth must be a positive integer')

    # The top bit carries weight -2**(bitwidth-1); the bits below it carry
    # their usual positive weights.
    sign_bit = 1 << (bitwidth - 1)
    positive_part = value & (sign_bit - 1)
    negative_part = value & sign_bit
    return positive_part - negative_part
# Return type of infer_val_and_bitwidth (and of the _convert_* helpers it
# dispatches to): field 0 is `value`, field 1 is `bitwidth`.
ValueBitwidthTuple = collections.namedtuple('ValueBitwidthTuple', ['value', 'bitwidth'])
[docs]
def infer_val_and_bitwidth(rawinput, bitwidth=None, signed=False):
    """ Return a tuple (value, bitwidth) inferred from the specified input.

    :param rawinput: a bool, int, or verilog-style string constant
    :param int bitwidth: an integer bitwidth or (by default) None
    :param bool signed: a bool (by default set False) to include bits for proper two's complement
    :return: tuple of integers (`value`, `bitwidth`)
    :rtype: (int, int)
    :raises PyrtlError: if `rawinput` is not a bool, an integer, or a string,
        or if the type-specific conversion rejects the combination of value,
        `bitwidth`, and `signed`

    Given a boolean, integer, or verilog-style string constant, this function
    returns a tuple of two integers (`value`, `bitwidth`) which are inferred
    from the specified `rawinput`. The tuple returned is, in fact, a named
    tuple with names `.value` and `.bitwidth` for fields 0 and 1 respectively.
    If `signed` is True, bits will be included to ensure a proper two's
    complement representation is possible, otherwise it is assumed all bits
    can be used for standard unsigned representation. Error checks are
    performed that determine if the bitwidths specified are sufficient and
    appropriate for the values specified.

    Examples can be found below ::

        infer_val_and_bitwidth(2, bitwidth=5) == (2, 5)
        infer_val_and_bitwidth(3) == (3, 2)  # bitwidth inferred from value
        infer_val_and_bitwidth(3, signed=True) == (3, 3)  # need a bit for the leading zero
        infer_val_and_bitwidth(-3, signed=True) == (5, 3)  # 5 = -3 & 0b111 = ..111101 & 0b111
        infer_val_and_bitwidth(-4, signed=True) == (4, 3)  # 4 = -4 & 0b111 = ..111100 & 0b111
        infer_val_and_bitwidth(-3, bitwidth=5, signed=True) == (29, 5)
        infer_val_and_bitwidth(-3) ==> Error  # negative numbers require bitwidth or signed=True
        infer_val_and_bitwidth(3, bitwidth=2) == (3, 2)
        infer_val_and_bitwidth(3, bitwidth=2, signed=True) ==> Error  # need space for sign bit
        infer_val_and_bitwidth(True) == (1, 1)
        infer_val_and_bitwidth(False) == (0, 1)
        infer_val_and_bitwidth("5'd12") == (12, 5)
        infer_val_and_bitwidth("5'b10") == (2, 5)
        infer_val_and_bitwidth("5'b10").bitwidth == 5
        infer_val_and_bitwidth("5'b10").value == 2
        infer_val_and_bitwidth("8'B 0110_1100") == (108, 8)
    """
    # bool must be tested before Integral: bool is itself an integral type.
    if isinstance(rawinput, bool):
        return _convert_bool(rawinput, bitwidth, signed)
    elif isinstance(rawinput, numbers.Integral):
        return _convert_int(rawinput, bitwidth, signed)
    elif isinstance(rawinput, str):
        return _convert_verilog_str(rawinput, bitwidth, signed)
    else:
        # The two literals are concatenated, so the separator must be
        # spelled out explicitly to keep the message readable.
        raise PyrtlError('error, the value provided is of an improper type, "%s"; '
                         'proper types are bool, int, and string' % type(rawinput))
def _convert_bool(bool_val, bitwidth=None, signed=False):
    """ Convert a boolean into a ``ValueBitwidthTuple`` of bitwidth 1.

    :param bool bool_val: the boolean to convert
    :param bitwidth: must be None (treated as 1) or equal to 1
    :param bool signed: must be falsy; booleans cannot be signed
    :return: ``ValueBitwidthTuple(int(bool_val), bitwidth)``
    :raises PyrtlError: if `signed` is set or `bitwidth` is not 1
    """
    if signed:
        raise PyrtlError('error, booleans cannot be signed (convert to int first)')
    as_int = int(bool_val)
    width = 1 if bitwidth is None else bitwidth
    if width != 1:
        raise PyrtlError('error, boolean has bitwidth not equal to 1')
    return ValueBitwidthTuple(as_int, width)
def _convert_int(val, bitwidth=None, signed=False):
    """ Convert an integer into a ``ValueBitwidthTuple``.

    :param val: the integer to convert (may be negative)
    :param bitwidth: explicit bitwidth, or None to infer the smallest one
    :param bool signed: when True, reserve room for a two's complement sign bit
    :return: ``ValueBitwidthTuple(value, bitwidth)`` where `value` is the
        non-negative bit pattern (two's complement for negative `val`)
    :raises PyrtlError: if the bitwidth is too small for `val`, or if `val` is
        negative and neither `signed` nor `bitwidth` was given
    """
    if val < 0:
        if bitwidth is None:
            if not signed:
                raise PyrtlError(
                    'negative constants require either signed=True or specified bitwidth')
            # Smallest two's complement width: one sign bit on top of the
            # bits needed for ~val (-1 needs just the sign bit).
            bitwidth = 1 if val == -1 else len(bin(~val)) - 1
        # Everything from the sign bit upwards must be ones, i.e. shifting
        # the value down to its sign bit must leave exactly -1.
        if (val >> bitwidth - 1) != -1:
            raise PyrtlError('insufficient bits for negative number')
        twos_complement = val & ((1 << bitwidth) - 1)
        return ValueBitwidthTuple(twos_complement, bitwidth)

    # Non-negative value: count its binary digits (bin() prefixes "0b").
    needed_bits = len(bin(val)) - 2
    if signed and val != 0:
        needed_bits += 1  # leading zero so the value does not read as negative
    if bitwidth is None:
        bitwidth = needed_bits
    elif bitwidth < needed_bits:
        raise PyrtlError('bitwidth specified is insufficient to represent constant')
    return ValueBitwidthTuple(val, bitwidth)
def _convert_verilog_str(val, bitwidth=None, signed=False):
    """ Convert a verilog-style constant string into a ``ValueBitwidthTuple``.

    :param str val: constant of the form ``<width>'<base><digits>``, with an
        optional leading ``-``; the base letter is one of b, o, d, h, x
        (decimal when omitted) and underscores in the digits are ignored
    :param bitwidth: optional bitwidth that, when truthy, must equal the
        width written in the string
    :param bool signed: must be falsy; not supported for string constants
    :return: ``ValueBitwidthTuple(value, bitwidth)`` with the width taken from
        the string
    :raises PyrtlError: if the string is malformed, uses the ``s`` (signed)
        marker, does not fit in its width, or disagrees with `bitwidth`
    """
    if signed:
        raise PyrtlError('error, "signed" option with verilog-style string constants not supported')

    bases = {'b': 2, 'o': 8, 'd': 10, 'h': 16, 'x': 16}
    requested_bitwidth = bitwidth

    is_negative = val.startswith('-')
    if is_negative:
        val = val[1:]

    pieces = val.lower().split("'")
    if len(pieces) != 2:
        raise PyrtlError('error, string not in verilog style format')

    try:
        bitwidth = int(pieces[0])
        digits = pieces[1]
        marker = digits[0]
        if marker == 's':
            raise PyrtlError('error, signed integers are not supported in verilog-style constants')
        if marker in bases:
            base = bases[marker]
            digits = digits[1:]
        else:
            # No base letter: the digits are decimal.
            base = 10
        num = int(digits.replace('_', ''), base)
    except (IndexError, ValueError):
        raise PyrtlError('error, string not in verilog style format')

    if is_negative and num:
        # The magnitude must leave the top bit free before it is negated
        # into two's complement form.
        if (num >> bitwidth - 1):
            raise PyrtlError('error, insufficient bits for negative number')
        num = (1 << bitwidth) - num

    if requested_bitwidth and requested_bitwidth != bitwidth:
        raise PyrtlError('error, bitwidth parameter of constant does not match'
                         ' the bitwidth infered from the verilog style specification'
                         ' (if bitwidth=None is used, pyrtl will determine the bitwidth from the'
                         ' verilog-style constant specification)')

    if num >> bitwidth != 0:
        raise PyrtlError('specified bitwidth %d for verilog constant insufficient to store value %d'
                         % (bitwidth, num))

    return ValueBitwidthTuple(num, bitwidth)
def get_stacks(*wires):
    """ Return the creation call stacks of several wires as one string.

    :param wires: the wires to report on; the first one decides whether call
        information is considered available
    :return: for each wire, its string form followed by its `get_stack`
        report, joined by newlines; or a hint message when the first wire has
        no recorded call stack
    :rtype: str
    """
    first_stack = getattr(wires[0], 'init_call_stack', None)
    if first_stack:
        reports = [str(wire) + ":\n" + get_stack(wire) for wire in wires]
        return '\n'.join(reports)
    return ' No call info found for wires: use set_debug_mode() ' \
           'to provide more information\n'
def get_stack(wire):
    """ Return a printable report of the call stack recorded when `wire` was created.

    :param WireVector wire: the wire to report on
    :return: the recorded frames (all but the last one) under a traceback
        header, or a hint message when no call stack was recorded
    :rtype: str
    :raises PyrtlError: if `wire` is not a WireVector
    """
    if not isinstance(wire, WireVector):
        raise PyrtlError('Only WireVectors can be traced')

    recorded = getattr(wire, 'init_call_stack', None)
    if not recorded:
        return ' No call info found for wire: use set_debug_mode()'\
               ' to provide more information'

    # The final recorded frame is left out of the report.
    shown_frames = ' '.join(recorded[:-1])
    return "Wire Traceback, most recent call last \n" + shown_frames + "\n"
def _check_for_loop(block=None):
    """ Peel away all logic that can be resolved and report what is left.

    :param Block block: block to examine (defaults to the working block)
    :return: None when every net could be resolved; otherwise a tuple
        ``(wires_left, logic_left)`` of the wires and nets that remain

    Starting from all wires except Inputs, Consts, Outputs, and Registers, a
    net counts as resolved once none of its arguments is still in the set of
    unresolved wires. Resolving a net also resolves its destinations. This is
    repeated until a full pass resolves nothing more.
    """
    block = working_block(block)
    unresolved_nets = block.logic.copy()
    unresolved_wires = block.wirevector_subset(exclude=(Input, Const, Output, Register))

    while True:
        # Collected separately because the set being iterated must not be
        # mutated during iteration.
        resolved_now = set()
        for net in unresolved_nets:
            if all(arg not in unresolved_wires for arg in net.args):
                resolved_now.add(net)
                unresolved_wires.difference_update(net.dests)
        if not resolved_now:
            break  # a full pass made no progress
        unresolved_nets -= resolved_now

    if len(unresolved_nets) == 0:
        return None
    return unresolved_wires, unresolved_nets
def find_loop(block=None):
    """ Search a block for a loop among its nets and return the nets on it.

    :param Block block: block to examine (defaults to the working block)
    :return: None when `_check_for_loop` reports nothing left over; otherwise
        a list of search-state objects, each exposing ``dst_w`` (a wire) and
        ``net`` (the net driving that wire), collected from the top of the
        search stack back to the state whose wire closes the loop
    :raises PyrtlError: if the search finishes without identifying a loop

    The block is sanity checked first. The search is an iterative depth-first
    walk from a randomly chosen leftover wire, following each net's arguments
    until a wire that is already on the current path is met again.
    """
    block = working_block(block)
    block.sanity_check()  # make sure that the block is sane first
    result = _check_for_loop(block)
    if not result:
        return
    wires_left, logic_left = result
    import random

    class _FilteringState(object):
        # One entry of the explicit search stack: the wire being explored and
        # the index of the argument of its driving net examined so far
        # (-1 means the entry has not been visited yet). A `net` attribute is
        # attached on the first visit.
        def __init__(self, dst_w):
            self.dst_w = dst_w
            self.arg_num = -1

    def dead_end():
        # clean up after a wire is found to not be part of the loop
        # (cur_item and checking_stack are looked up in the enclosing scope
        # at call time, so this always acts on the entry being processed)
        wires_left.discard(cur_item.dst_w)
        current_wires.discard(cur_item.dst_w)
        del checking_stack[-1]

    # now making a map to quickly look up nets
    dest_nets = {dest_w: net_ for net_ in logic_left for dest_w in net_.dests}
    # Any leftover wire serves as the starting point; one is picked at random.
    initial_w = random.sample(list(wires_left), 1)[0]
    current_wires = set()
    checking_stack = [_FilteringState(initial_w)]

    # we don't use a recursive method as Python has a limited stack (default: 999 frames)
    while len(checking_stack):
        cur_item = checking_stack[-1]
        if cur_item.arg_num == -1:
            # first time testing this item
            if cur_item.dst_w not in wires_left:
                # wire is not among the leftover wires: nothing to follow
                dead_end()
                continue
            current_wires.add(cur_item.dst_w)
            cur_item.net = dest_nets[cur_item.dst_w]
            if cur_item.net.op == 'r':
                # nets with op 'r' are not followed any further
                dead_end()
                continue
        cur_item.arg_num += 1  # go to the next item
        if cur_item.arg_num == len(cur_item.net.args):
            # every argument of this net has been explored without success
            dead_end()
            continue
        next_wire = cur_item.net.args[cur_item.arg_num]
        if next_wire not in current_wires:
            current_wires.add(next_wire)
            checking_stack.append(_FilteringState(next_wire))
        else:  # We have found the loop!!!!!
            # Walk back down the stack until the entry for the repeated wire
            # is reached; those entries make up the loop.
            loop_info = []
            for f_state in reversed(checking_stack):
                loop_info.append(f_state)
                if f_state.dst_w is next_wire:
                    break
            else:
                raise PyrtlError("Shouldn't get here! Couldn't figure out the loop")
            return loop_info
    raise PyrtlError("Error in detecting loop")
def find_and_print_loop(block=None):
    """ Run `find_loop` on a block, print the outcome, and return it.

    :param Block block: block to examine (passed through to `find_loop`)
    :return: whatever `find_loop` returned
    """
    found = find_loop(block)
    print_loop(found)
    return found
def print_loop(loop_data):
    """ Print the nets of a loop, or a notice that there is none.

    :param loop_data: sequence of objects with a ``net`` attribute (such as
        the result of `find_loop`); a falsy value means no loop was found
    :return: None
    """
    if not loop_data:
        print("No Loop Found")
        return
    print("Loop found:")
    net_lines = ["{}".format(state.net) for state in loop_data]
    print('\n'.join(net_lines))
    print("")
def _currently_in_jupyter_notebook():
    """ Return True if running under Jupyter notebook, otherwise return False.

    Checking only for the presence of __IPYTHON__ is not enough because it is
    present in both Jupyter notebooks and IPython terminals; the class name of
    the active shell is inspected instead.
    """
    try:
        # get_ipython() is in the global namespace when ipython is started
        shell_name = get_ipython().__class__.__name__
    except NameError:
        return False  # Probably standard Python interpreter
    # 'ZMQInteractiveShell' means Jupyter notebook or qtconsole. Any other
    # shell, including 'TerminalInteractiveShell' (IPython in a terminal),
    # does not count.
    return shell_name == 'ZMQInteractiveShell'
def _print_netlist_latex(netlist):
    """ Display each net in netlist as a row of a Latex array (via IPython). """
    from IPython.display import display, Latex  # pylint: disable=import-error

    array_open = '\n\\begin{array}{ \\| c \\| c \\| l \\| }\n'
    first_rule = '\n\\hline\n'
    # Each net supplies its own row text through str(); rows are separated
    # by horizontal rules.
    rows = '\\hline\n'.join(str(net) for net in netlist)
    array_close = '\\hline\n\\end{array}\n'
    display(Latex(array_open + first_rule + rows + array_close))
class _NetCount:
    """ Track the number of nets in a block across repeated passes.

    Helps decide when to stop an iteration that depends on the number of
    nets, which is mainly useful for optimization loops.
    """

    def __init__(self, block=None):
        self.block = working_block(block)
        # Start from a deliberately inflated count (1000x the real one),
        # presumably so that the first call to shrank() reports progress.
        self.prev_nets = len(self.block.logic) * 1000

    def shrank(self, block=None, percent_diff=0, abs_diff=1):
        """
        Return whether a block has fewer nets than at the previous check.

        :param Block block: block to check (defaults to the block given at
            construction)
        :param Number percent_diff: percentage difference threshold
        :param int abs_diff: absolute difference threshold
        :return: boolean

        The current net count must not exceed the previous count reduced by
        both the percentage and the absolute thresholds. The current count is
        remembered for the next call either way.
        """
        target_block = self.block if block is None else block
        current_count = len(target_block.logic)
        required_count = self.prev_nets * (1 - percent_diff) - abs_diff
        self.prev_nets = current_count
        return current_count <= required_count

    # Alias kept so that both spellings can be used by callers.
    shrinking = shrank
# Description of one component: its `name`, its `bitwidth`, and its `type`.
# A `type` of None means that the default component_type should be used
# instead.
_ComponentMeta = collections.namedtuple('_ComponentMeta', 'name bitwidth type')
def _make_component(component_meta: _ComponentMeta, block: Block, name: str,
                    component_type, component_value):
    '''Build one component and hook up its driver.

    :param component_meta: name, bitwidth and optional type of the component.
    :param block: block handed to the constructed component.
    :param name: name of the enclosing struct/matrix. When empty, the
        component is constructed with an empty name.
    :param component_type: default type for the component. For nested
        ``wire_struct``/``wire_matrix`` components it is forwarded as their
        ``concatenated_type`` instead.
    :param component_value: driver for the component; may be ``None``.
    :return: the newly constructed component.
    '''
    # Resolve the component's actual type.
    #
    # A type recorded in the _ComponentMeta means the component is itself a
    # wire_struct or wire_matrix: that type is the component's primary type
    # and the default component_type becomes its concatenated_type.
    # Otherwise the default component_type is used as-is.
    actual_component_type = component_meta.type
    if actual_component_type is None:
        actual_component_type = component_type

    # Derive the component's name from the parent's name.
    if len(name) == 0:
        component_name = ''
    elif isinstance(component_meta.name, str):
        # wire_struct components have string names and print with dots,
        # like `struct.component`.
        component_name = '.'.join((name, component_meta.name))
    else:
        # wire_matrix components have integer names and print with
        # brackets, like `matrix[0]`.
        component_name = ''.join((name, '[', str(component_meta.name), ']'))

    # A fresh wire_struct, wire_matrix, or WireVector is always created for
    # the component. When component_value already has the right type and
    # name it could be reused directly, but detecting that reliably is
    # complicated, so things are kept simple for now.
    #
    # A component is always initialized from a single concatenated
    # component_value covering all of its wires, which means nested
    # wire_structs and wire_matrices always take the slicing path.
    if hasattr(actual_component_type, '_is_wire_struct'):
        # Nested wire_struct; component_value may be None. The value is
        # passed under the keyword named after the struct's class.
        return actual_component_type(
            name=component_name,
            block=block,
            concatenated_type=component_type,
            **{actual_component_type._class_name: component_value})

    if hasattr(actual_component_type, '_is_wire_matrix'):
        # Nested wire_matrix; component_value may be None.
        return actual_component_type(
            name=component_name,
            block=block,
            concatenated_type=component_type,
            values=[component_value])

    if (isinstance(component_value, int)
            and actual_component_type is WireVector):
        # Special case: an integer driving a plain WireVector is simplified
        # to a Const.
        return Const(
            bitwidth=component_meta.bitwidth, name=component_name,
            block=block, val=component_value)

    # Ordinary WireVector (or subclass) component.
    component = actual_component_type(
        bitwidth=component_meta.bitwidth,
        name=component_name,
        block=block)
    if component_value is not None:
        component <<= component_value
    return component
def _bitslice(value: int, start: int, end: int) -> int:
    '''Extract bits ``start`` (inclusive) up to ``end`` (exclusive) of ``value``.

    Bit 0 is the least significant bit. The selected bits are returned
    shifted down so that bit ``start`` of ``value`` becomes bit 0 of the
    result.
    '''
    width = end - start
    low_bits = (1 << width) - 1  # ``width`` ones in the low positions.
    shifted = value >> start
    return shifted & low_bits
def _slice(block: Block,
           schema: list[_ComponentMeta],
           bitwidth: int,
           component_type: type,
           name: str,
           concatenated,
           components,
           concatenated_value):
    '''Split ``concatenated`` into its components.

    ``concatenated_value`` drives ``concatenated``. Inspecting it enables
    some optimizations; for example, integer values are sliced immediately
    instead of generating slicing logic.

    Components are taken from the most significant end first, in schema
    order, and stored in ``components`` under each component's name.
    '''
    # Connect the driver, unless there is none or the concatenated wire is
    # a Const.
    if not (concatenated_value is None or isinstance(concatenated, Const)):
        concatenated <<= concatenated_value

    slice_immediately = isinstance(concatenated_value, int)
    high = bitwidth
    for meta in schema:
        low = high - meta.bitwidth
        if slice_immediately:
            # Special case: slice integer values right away.
            piece = _bitslice(concatenated_value, low, high)
        else:
            piece = concatenated[low:high]
        # The next component sits directly below this one.
        high = low
        components[meta.name] = _make_component(
            component_meta=meta, block=block, name=name,
            component_type=component_type, component_value=piece)
def _concatenate(block: Block,
                 schema: list[_ComponentMeta],
                 component_type: type,
                 name: str,
                 concatenated,
                 components,
                 component_map):
    '''Build every component from ``component_map`` and drive ``concatenated``.

    Components are created in schema order, stored in ``components`` under
    their names, and then passed to ``concat`` in that same order to form
    the value assigned to ``concatenated``.
    '''
    ordered = []
    for meta in schema:
        made = _make_component(
            component_meta=meta, block=block, name=name,
            component_type=component_type,
            component_value=component_map[meta.name])
        components[meta.name] = made
        ordered.append(made)
    concatenated <<= concat(*ordered)
[docs]
def wire_struct(wire_struct_spec):
    '''Decorator that assigns names to ``WireVector`` slices.

    ``@wire_struct`` assigns names to *non-overlapping* ``WireVector``
    slices. Suppose we have an 8-bit wide ``WireVector`` called ``byte``. All
    8 bits can be referred to as ``byte``, but ``@wire_struct`` also lets us
    refer to slices by name: the high 4 bits could be ``byte.high`` and the
    low 4 bits ``byte.low``. Without ``@wire_struct`` these slices would be
    written ``byte[4:8]`` and ``byte[0:4]``, which is harder to read and
    prone to off-by-one errors.

    The example ``Byte`` ``@wire_struct`` can be defined as::

        @wire_struct
        class Byte:
            high: 4  # 'high' is name for the 4 most significant bits.
            low: 4   # 'low' is name for the 4 least significant bits.

    ------------
    Construction
    ------------

    Once a ``@wire_struct`` class is defined, it is instantiated by providing
    drivers for all of its wires. There are two ways to do that:

    1. Provide a driver for *each* component wire, for example::

           byte = Byte(high=0xA, low=0xB)

       The component names (``high``, ``low``) are used as keyword args for
       the constructor. Drivers must be provided for *all* components.

    2. Provide a driver for the entire ``@wire_struct``, for example::

           byte = Byte(Byte=0xAB)

       The class name (``Byte``) is used as a keyword arg for the
       constructor.

    ----------------
    Accessing Slices
    ----------------

    After instantiation, the instance functions as a ``WireVector``
    containing all the wires. For example, ``byte`` functions as a
    ``WireVector`` with bitwidth 8::

        byte = Byte(Byte=0xAB)
        print(byte.bitwidth)  # Prints 8.

    The named slices are reached through the ``.`` operator
    (``__getattr__``), for example ``byte.high`` and ``byte.low``, which both
    function as ``WireVector`` with bitwidth 4::

        byte = Byte(Byte=0xAB)
        print(byte.high.bitwidth)  # Prints 4.
        print(byte.low.bitwidth)  # Prints 4.

    Both the instance and the slices are first-class ``WireVector``, so they
    can be manipulated with all the usual PyRTL operators.

    .. NOTE::

       ``len(byte)`` returns the number of components in the
       ``@wire_struct`` (2), not the total bitwidth (8 == 4 + 4). To get the
       total bitwidth, use ``byte.bitwidth`` or ``len(as_wires(byte))``.

    ------
    Naming
    ------

    A ``@wire_struct`` can be assigned a name in the usual way::

        byte = Byte(name='b', high=0xC, low=0xD)
        byte = Byte(name='b', Byte=0xCD)

    When a ``@wire_struct`` is assigned a name (``b``), its components are
    assigned dotted names (``b.high``, ``b.low``)::

        print(byte.high.name)  # Prints 'b.high'.
        print(byte.low.name)  # Prints 'b.low'.

    .. WARNING::

       All ``@wire_struct`` names are only set during construction. You can
       later rename a ``@wire_struct`` or its components, but those changes
       are local, and will not propagate to other ``@wire_struct``
       components. Renaming a ``@wire_struct`` or its components is strongly
       discouraged.

    -----------
    Composition
    -----------

    ``@wire_struct`` can be composed with itself, and with ``wire_matrix``.
    For example, we can define a ``Pixel`` that contains three ``Byte``::

        @wire_struct
        class Pixel:
            red: Byte
            green: Byte
            blue: Byte

    Drivers must be specified for all components, but they can be specified
    at any level. All these examples construct an equivalent
    ``@wire_struct``::

        pixel = Pixel(Pixel=0xABCDEF)
        pixel = Pixel(red=0xAB, green=0xCD, blue=0xEF)
        pixel = Pixel(red=Byte(high=0xA, low=0xB), green=0xCD, blue=0xEF)
        pixel = Pixel(red=Byte(high=0xA, low=0xB),
                      green=Byte(high=0xC, low=0xD),
                      blue=0xEF)

    Hierarchical ``@wire_struct`` components are accessed by composing ``.``
    operators::

        pixel
        pixel.red
        pixel.red.high
        pixel.red.low
        pixel.green
        pixel.green.high
        pixel.green.low
        pixel.blue
        pixel.blue.high
        pixel.blue.low

    ``@wire_struct`` can be composed with ``wire_matrix``::

        Word = wire_matrix(component_schema=8, size=4)

        @wire_struct
        class CacheLine:
            address: Word
            data: Word
            valid: 1

        cache_line = CacheLine(address=0x01234567, data=0x89ABCDEF, valid=1)

    Leaf-level components can be accessed by combining the ``.`` and ``[]``
    operators, for example ``cache_line.address[3]``.

    -----
    Types
    -----

    The type of a ``@wire_struct``'s components can be changed to a
    ``WireVector`` subclass like :py:class:`Input` or :py:class:`Output`
    with the ``component_type`` constructor argument::

        # Generates Outputs named ``output_byte.low`` and
        # ``output_byte.high``.
        output_byte = Byte(name='output_byte', component_type=pyrtl.Output,
                           Byte=0xCD)

    The type of the ``@wire_struct`` itself can be changed with the
    ``concatenated_type`` constructor argument::

        # Generates an Input named ``input_byte``.
        input_byte = Byte(name='input_byte', concatenated_type=pyrtl.Input)

    .. NOTE::

       No values are specified for ``input_byte`` because its value is not
       known until simulation time.

    '''
    def _to_component_meta(component_name, annotation):
        # Translate one ``name: annotation`` pair of the decorated class.
        if isinstance(annotation, int):
            # An ordinary component ("foo: 4") that should use the default
            # component_type.
            return _ComponentMeta(
                name=component_name, bitwidth=annotation, type=None)
        # A nested component ("bar: Byte") that must use the nested
        # component's type.
        return _ComponentMeta(
            name=component_name, bitwidth=annotation._bitwidth,
            type=annotation)

    # The decorated class' annotations (dict of attr_name: attr_value) become
    # a list of _ComponentMetas.
    #
    # dict iteration order is guaranteed to be insertion order in Python 3.7+.
    schema = [
        _to_component_meta(component_name, annotation)
        for component_name, annotation
        in wire_struct_spec.__annotations__.items()]

    total_bitwidth = sum(meta.bitwidth for meta in schema)

    # Name of the decorated class.
    class_name = wire_struct_spec.__name__

    class _WireStruct(WrappedWireVector):
        '''``wire_struct`` implementation: Concatenate or slice ``WireVector``.

        A ``wire_struct`` either concatenates component ``WireVector`` into
        its full value, *or* slices its full value into component
        ``WireVector``. It only ever does one of the two, and ``__init__``
        decides which.
        '''
        _bitwidth = total_bitwidth
        _class_name = class_name
        _is_wire_struct = True

        def __init__(self, name='', block=None,
                     concatenated_type=WireVector, component_type=WireVector,
                     **kwargs):
            '''Concatenate or slice ``WireVector`` components.

            :param str name: The name of the concatenated wire. Must be
                unique. If none is provided, one will be autogenerated. If a
                name is provided, components will be assigned names of the
                form "{name}.{component_name}".
            :param Block block: The block containing the concatenated and
                component wires. Defaults to the working block.
            :param type concatenated_type: Type for the concatenated
                ``WireVector``.
            :param type component_type: Type for each component.
            :raises PyrtlError: if unexpected keyword args are passed, or if
                a component value is missing when concatenating.

            The remaining keyword args specify values for all wires. A
            concatenated value must be passed with the keyword arg matching
            the decorated class name. For example, if the decorated class
            is::

                @wire_struct
                class Byte:
                    high: 4  # high is the 4 most significant bits.
                    low: 4   # low is the 4 least significant bits.

            then the concatenated value must be provided like this::

                byte = Byte(Byte=0xAB)

            If component values are provided instead, they are set by keyword
            args matching the component names::

                byte = Byte(low=0xA, high=0xB)

            '''
            # The concatenated WireVector holds all of the _WireStruct's
            # wires. WrappedWireVector (base class) forwards attribute and
            # method accesses on this _WireStruct to it.
            if (concatenated_type is WireVector
                    and isinstance(kwargs.get(class_name), int)):
                # Special case: simplify the concatenated type to Const.
                concatenated = Const(
                    bitwidth=self._bitwidth, name=name, block=block,
                    val=kwargs[class_name])
            else:
                concatenated = concatenated_type(
                    bitwidth=self._bitwidth, name=name, block=block)
            super().__init__(wire=concatenated)

            # self._components maps each component name to that component's
            # WireVector. It is stored directly in the instance __dict__.
            components = {}
            self.__dict__['_components'] = components

            # Input and Register special cases: any passed values are
            # replaced by None drivers.
            if concatenated_type is Input or concatenated_type is Register:
                kwargs = {class_name: None}
            elif component_type is Input or component_type is Register:
                kwargs = dict.fromkeys(meta.name for meta in schema)

            if class_name in kwargs:
                # A concatenated value was provided, so nothing else may be
                # passed alongside it.
                extra = next(
                    (key for key in kwargs if key != class_name), None)
                if extra is not None:
                    raise PyrtlError(
                        'Do not pass additional kwargs to @wire_struct '
                        f'when slicing. ("{class_name}" was passed so '
                        f'don\'t pass "{extra}")')
                # Slice the concatenated value into components.
                _slice(block=block, schema=schema, bitwidth=self._bitwidth,
                       component_type=component_type, name=name,
                       concatenated=concatenated, components=components,
                       concatenated_value=kwargs[class_name])
                return

            # Component values were provided; concatenate them.
            expected_names = [meta.name for meta in schema]
            # Every component needs a value.
            missing = next(
                (key for key in expected_names if key not in kwargs), None)
            if missing is not None:
                raise PyrtlError(
                    'You must provide kwargs for all @wire_struct '
                    'components when concatenating (missing kwarg '
                    f'"{missing}")')
            # And nothing besides component values may be passed.
            unexpected = next(
                (key for key in kwargs if key not in expected_names), None)
            if unexpected is not None:
                raise PyrtlError(
                    'Do not pass additional kwargs to @wire_struct '
                    'when concatenating (don\'t pass '
                    f'"{unexpected}")')
            _concatenate(block=block, schema=schema,
                         component_type=component_type, name=name,
                         concatenated=concatenated, components=components,
                         component_map=kwargs)

        def __getattr__(self, component_name: str):
            '''Retrieve a component by name.

            Components are either concatenated to form the concatenated
            ``WireVector``, or sliced from it. Names that are not components
            are handed to the base class's ``__getattr__``.

            :param component_name: The name of the component wire.
            '''
            components = self.__dict__['_components']
            if component_name not in components:
                return super().__getattr__(component_name)
            return components[component_name]

        def __len__(self):
            '''Return the number of components (not the total bitwidth).'''
            return len(self.__dict__['_components'])

    return _WireStruct
[docs]
def wire_matrix(component_schema, size: int):
    '''Returns a class that assigns numbered indices to ``WireVector`` slices.

    ``wire_matrix`` assigns numbered indices to *non-overlapping*
    ``WireVector`` slices. ``wire_matrix`` is very similar to
    :py:func:`wire_struct`, so read :py:func:`wire_struct`'s documentation
    first.

    An example 32-bit ``Word`` ``wire_matrix``, which represents a group of
    four bytes, can be defined as::

        Word = wire_matrix(component_schema=8, size=4)

    .. NOTE::

       ``wire_matrix`` returns a class, like ``namedtuple``.

    :param component_schema: either an integer bitwidth for each component,
        or a ``wire_struct``/``wire_matrix`` class used as the type of each
        component.
    :param size: number of components.
    :return: a class wrapping a ``WireVector`` of bitwidth
        ``component bitwidth * size``.

    ------------
    Construction
    ------------

    Once a ``wire_matrix`` class is defined, it can be instantiated by
    providing drivers for all of its wires. This can be done in two ways::

        # Provide a driver for each component, most significant bits first.
        word = Word(values=[0x89, 0xAB, 0xCD, 0xEF])

        # Provide a driver for all components.
        word = Word(values=[0x89ABCDEF])

    .. NOTE::

       When specifying drivers for each component, the most significant bits
       are specified first.

    After instantiating a ``wire_matrix``, regardless of how it was
    constructed, the instance functions as a ``WireVector`` containing all
    the wires, so ``word`` functions as a ``WireVector`` with bitwidth 32.
    The named slice can be accessed with square brackets (``__getitem__``),
    for example ``word[0]`` and ``word[3]``, which both function as
    ``WireVector`` with bitwidth 8. ``word[0]`` refers to the most
    significant byte, and ``word[3]`` refers to the least significant byte.
    Both the instance and the slices are first-class ``WireVector``, so they
    can be manipulated with all the usual PyRTL operators.

    ------
    Naming
    ------

    A ``wire_matrix`` can be assigned a name in the usual way::

        # The whole Word is named 'w', so the components will have names
        # w[0], w[1], ...
        word = Word(name='w', values=[0x89, 0xAB, 0xCD, 0xEF])
        word = Word(name='w', values=[0x89ABCDEF])

    -----------
    Composition
    -----------

    ``wire_matrix`` can be composed with itself and ``@wire_struct``. For
    example, we can define some multi-dimensional byte arrays::

        Array1D = wire_matrix(component_schema=8, size=2)
        Array2D = wire_matrix(component_schema=Array1D, size=2)

    Drivers must be specified for all components, but they can be specified
    at any level. All these examples construct an equivalent
    ``wire_matrix``::

        array_2d = Array2D(values=[0x89AB, 0xCDEF])
        array_2d = Array2D(values=[Array1D(values=[0x89, 0xAB]),
                                   0xCDEF])
        array_2d = Array2D(values=[Array1D(values=[0x89, 0xAB]),
                                   Array1D(values=[0xCD, 0xEF])])

    ----------------
    Accessing Slices
    ----------------

    Hierarchical components are accessed by composing ``[]`` operators, for
    example::

        print(array_2d[0][0].bitwidth)  # Prints 8.
        print(array_2d[0][1].bitwidth)  # Prints 8.

    When ``wire_matrix`` is composed with ``@wire_struct``, components can be
    accessed by combining the ``[]`` and ``.`` operators::

        @wire_struct
        class Byte:
            high: 4
            low: 4

        Array1D = wire_matrix(component_schema=Byte, size=2)
        array_1d = Array1D(values=[0xAB, 0xCD])

        print(array_1d[0].high.bitwidth)  # Prints 4.

    .. NOTE::

       ``len(array_1d)`` returns the number of components in the
       ``wire_matrix`` (2), not the total bitwidth (16 == 2 * 8). To get the
       total bitwidth, use ``array_1d.bitwidth`` or
       ``len(as_wires(array_1d))``.

    -----
    Types
    -----

    You can change the type of a ``wire_matrix``'s components with the
    ``component_type`` constructor argument::

        # Generates Outputs named ``output_word[0]``, ``output_word[1]``, ...
        word = Word(name='output_word',
                    component_type=pyrtl.Output,
                    values=[0x89ABCDEF])

    You can change the type of the ``wire_matrix`` itself with the
    ``concatenated_type`` constructor argument::

        # Generates an Input named ``input_word``.
        word = Word(name='input_word', concatenated_type=pyrtl.Input)

    .. NOTE::

       No values are specified for ``input_word`` because its value is not
       known until simulation time.

    '''
    # Determine each component's bitwidth.
    if (hasattr(component_schema, '_is_wire_struct')
            or hasattr(component_schema, '_is_wire_matrix')):
        # Nested struct/matrix: its class carries the bitwidth and is kept as
        # the component type.
        component_bitwidth = component_schema._bitwidth
    else:
        # Plain bitwidth: components use the default component_type, which
        # is signalled by a None schema type.
        component_bitwidth = component_schema
        component_schema = None

    class _WireMatrix(WrappedWireVector):
        '''``wire_matrix`` implementation: Concatenate or slice ``WireVector``.

        Like ``wire_struct``, an instance either slices one concatenated
        value into its components or concatenates per-component values; the
        choice is made in ``__init__`` from the number of ``values``.
        '''
        _component_bitwidth = component_bitwidth
        _component_schema = component_schema
        _size = size
        _bitwidth = component_bitwidth * size
        _is_wire_matrix = True

        def __init__(self, name: str = '', block: Block = None,
                     concatenated_type=WireVector, component_type=WireVector,
                     values: list | None = None):
            '''Concatenate or slice ``WireVector`` components.

            :param name: name of the concatenated wire. When non-empty,
                components are named "{name}[{index}]".
            :param block: block passed to the concatenated and component
                wires.
            :param concatenated_type: type for the concatenated
                ``WireVector``.
            :param component_type: type for each component.
            :param values: one value to slice into components, or one value
                per component (most significant first) to concatenate.
                ``None`` (the default) is treated as an empty list. Ignored
                when ``concatenated_type`` or ``component_type`` is ``Input``
                or ``Register``.
            :raises PyrtlError: if the number of values is neither 1 nor the
                number of components.
            '''
            # A None sentinel replaces a mutable ``[]`` default, which would
            # be a single list object shared by every call.
            if values is None:
                values = []

            # The concatenated WireVector contains all the _WireMatrix's
            # wires. WrappedWireVector (base class) will forward all
            # attribute and method accesses on this _WireMatrix to the
            # concatenated WireVector.
            if (len(values) == 1 and isinstance(values[0], int)
                    and concatenated_type is WireVector):
                # Special case: simplify the concatenated type to Const.
                concatenated = Const(
                    bitwidth=self._bitwidth, name=name, block=block,
                    val=values[0])
            else:
                concatenated = concatenated_type(
                    bitwidth=self._bitwidth, name=name, block=block)
            super().__init__(wire=concatenated)

            # Components are named by their integer index.
            schema = [
                _ComponentMeta(name=index,
                               bitwidth=self._component_bitwidth,
                               type=component_schema)
                for index in range(self._size)]

            # Handle Input and Register special cases: any passed values are
            # replaced by None drivers.
            if concatenated_type is Input or concatenated_type is Register:
                values = [None]
            elif component_type is Input or component_type is Register:
                values = [None for _ in range(self._size)]

            # Placeholder slots, filled in by _slice or _concatenate below.
            self._components = [None] * len(schema)
            if len(values) == 1:
                # Concatenated value was provided. Slice it into components.
                _slice(block=block, schema=schema, bitwidth=self._bitwidth,
                       component_type=component_type, name=name,
                       concatenated=concatenated, components=self._components,
                       concatenated_value=values[0])
            else:
                if len(values) != len(schema):
                    raise PyrtlError(
                        'wire_matrix constructor expects 1 value to slice, or '
                        f'{len(schema)} values to concatenate (received '
                        f'{len(values)} values)')
                # Component values were provided; concatenate them.
                _concatenate(block=block, schema=schema,
                             component_type=component_type, name=name,
                             concatenated=concatenated,
                             components=self._components, component_map=values)

        def __getitem__(self, key):
            '''Return the component(s) selected by ``key``.'''
            return self._components[key]

        def __len__(self):
            '''Return the number of components (not the total bitwidth).'''
            return len(self._components)

    return _WireMatrix